Skip to content

Commit c0299d5

Browse files
committed
only update partials lists when needed
1 parent bf3f06d commit c0299d5

File tree

1 file changed

+27
-10
lines changed

1 file changed

+27
-10
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ LinearSolve.@concrete mutable struct DualLinearCache{DT}
5151
primal_u_cache
5252
primal_b_cache
5353

54+
# Cache validity flag for RHS precalculation optimization
55+
rhs_cache_valid
56+
5457
dual_A
5558
dual_b
5659
dual_u
@@ -96,14 +99,17 @@ end
9699
function xp_linsolve_rhs!(uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}},
97100
∂_b::Union{<:Partials, <:AbstractArray{<:Partials}}, cache::DualLinearCache)
98101

99-
# Update cached partials lists
100-
update_partials_list!(∂_A, cache.partials_A_list)
101-
update_partials_list!(∂_b, cache.partials_b_list)
102+
# Update cached partials lists if cache is invalid
103+
if !cache.rhs_cache_valid
104+
update_partials_list!(∂_A, cache.partials_A_list)
105+
update_partials_list!(∂_b, cache.partials_b_list)
106+
cache.rhs_cache_valid = true
107+
end
102108

103109
A_list = cache.partials_A_list
104110
b_list = cache.partials_b_list
105111

106-
# Compute rhs = b - A*uu using five-argument mul!
112+
# Compute rhs = b - A*uu using precalculated b_list and five-argument mul!
107113
for i in eachindex(b_list)
108114
cache.rhs_list[i] .= b_list[i]
109115
mul!(cache.rhs_list[i], A_list[i], uu, -1, 1)
@@ -116,8 +122,12 @@ function xp_linsolve_rhs!(
116122
uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}},
117123
∂_b::Nothing, cache::DualLinearCache)
118124

119-
# Update cached partials list for A
120-
update_partials_list!(∂_A, cache.partials_A_list)
125+
# Update cached partials list for A if cache is invalid
126+
if !cache.rhs_cache_valid
127+
update_partials_list!(∂_A, cache.partials_A_list)
128+
cache.rhs_cache_valid = true
129+
end
130+
121131
A_list = cache.partials_A_list
122132

123133
# Compute rhs = -A*uu using five-argument mul!
@@ -132,11 +142,15 @@ function xp_linsolve_rhs!(
132142
uu, ∂_A::Nothing, ∂_b::Union{<:Partials, <:AbstractArray{<:Partials}},
133143
cache::DualLinearCache)
134144

135-
# Update cached partials list for b
136-
update_partials_list!(∂_b, cache.partials_b_list)
145+
# Update cached partials list for b if cache is invalid
146+
if !cache.rhs_cache_valid
147+
update_partials_list!(∂_b, cache.partials_b_list)
148+
cache.rhs_cache_valid = true
149+
end
150+
137151
b_list = cache.partials_b_list
138152

139-
# Copy b_list to rhs_list
153+
# Copy precalculated b_list to rhs_list (no A*uu computation needed)
140154
for i in eachindex(b_list)
141155
cache.rhs_list[i] .= b_list[i]
142156
end
@@ -247,6 +261,7 @@ function __dual_init(
247261
similar(new_b),
248262
similar(new_b),
249263
similar(new_b),
264+
true, # Cache is initially valid
250265
A,
251266
b,
252267
zeros(dual_type, length(b))
@@ -284,13 +299,15 @@ function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val)
284299
setproperty!(dc.linear_cache, sym, val)
285300
end
286301

287-
# Update the partials if setting A or b
302+
# Update the partials and invalidate cache if setting A or b
288303
if sym === :A
289304
setfield!(dc, :dual_A, val)
290305
setfield!(dc, :partials_A, partial_vals(val))
306+
setfield!(dc, :rhs_cache_valid, false) # Invalidate cache
291307
elseif sym === :b
292308
setfield!(dc, :dual_b, val)
293309
setfield!(dc, :partials_b, partial_vals(val))
310+
setfield!(dc, :rhs_cache_valid, false) # Invalidate cache
294311
elseif sym === :u
295312
setfield!(dc, :dual_u, val)
296313
setfield!(dc, :partials_u, partial_vals(val))

0 commit comments

Comments
 (0)