@@ -51,6 +51,9 @@ LinearSolve.@concrete mutable struct DualLinearCache{DT}
5151 primal_u_cache
5252 primal_b_cache
5353
54+ # Cache validity flag for RHS precalculation optimization
55+ rhs_cache_valid
56+
5457 dual_A
5558 dual_b
5659 dual_u
9699function xp_linsolve_rhs! (uu, ∂_A:: Union{<:Partials, <:AbstractArray{<:Partials}} ,
97100 ∂_b:: Union{<:Partials, <:AbstractArray{<:Partials}} , cache:: DualLinearCache )
98101
99- # Update cached partials lists
100- update_partials_list! (∂_A, cache. partials_A_list)
101- update_partials_list! (∂_b, cache. partials_b_list)
102+ # Update cached partials lists if cache is invalid
103+ if ! cache. rhs_cache_valid
104+ update_partials_list! (∂_A, cache. partials_A_list)
105+ update_partials_list! (∂_b, cache. partials_b_list)
106+ cache. rhs_cache_valid = true
107+ end
102108
103109 A_list = cache. partials_A_list
104110 b_list = cache. partials_b_list
105111
106- # Compute rhs = b - A*uu using five-argument mul!
112+ # Compute rhs = b - A*uu using precalculated b_list and five-argument mul!
107113 for i in eachindex (b_list)
108114 cache. rhs_list[i] .= b_list[i]
109115 mul! (cache. rhs_list[i], A_list[i], uu, - 1 , 1 )
@@ -116,8 +122,12 @@ function xp_linsolve_rhs!(
116122 uu, ∂_A:: Union{<:Partials, <:AbstractArray{<:Partials}} ,
117123 ∂_b:: Nothing , cache:: DualLinearCache )
118124
119- # Update cached partials list for A
120- update_partials_list! (∂_A, cache. partials_A_list)
125+ # Update cached partials list for A if cache is invalid
126+ if ! cache. rhs_cache_valid
127+ update_partials_list! (∂_A, cache. partials_A_list)
128+ cache. rhs_cache_valid = true
129+ end
130+
121131 A_list = cache. partials_A_list
122132
123133 # Compute rhs = -A*uu using five-argument mul!
@@ -132,11 +142,15 @@ function xp_linsolve_rhs!(
132142 uu, ∂_A:: Nothing , ∂_b:: Union{<:Partials, <:AbstractArray{<:Partials}} ,
133143 cache:: DualLinearCache )
134144
135- # Update cached partials list for b
136- update_partials_list! (∂_b, cache. partials_b_list)
145+ # Update cached partials list for b if cache is invalid
146+ if ! cache. rhs_cache_valid
147+ update_partials_list! (∂_b, cache. partials_b_list)
148+ cache. rhs_cache_valid = true
149+ end
150+
137151 b_list = cache. partials_b_list
138152
139- # Copy b_list to rhs_list
153+ # Copy precalculated b_list to rhs_list (no A*uu computation needed)
140154 for i in eachindex (b_list)
141155 cache. rhs_list[i] .= b_list[i]
142156 end
@@ -247,6 +261,7 @@ function __dual_init(
247261 similar (new_b),
248262 similar (new_b),
249263 similar (new_b),
264+ true , # Cache is initially valid
250265 A,
251266 b,
252267 zeros (dual_type, length (b))
@@ -284,13 +299,15 @@ function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val)
284299 setproperty! (dc. linear_cache, sym, val)
285300 end
286301
287- # Update the partials if setting A or b
302+ # Update the partials and invalidate cache if setting A or b
288303 if sym === :A
289304 setfield! (dc, :dual_A , val)
290305 setfield! (dc, :partials_A , partial_vals (val))
306+ setfield! (dc, :rhs_cache_valid , false ) # Invalidate cache
291307 elseif sym === :b
292308 setfield! (dc, :dual_b , val)
293309 setfield! (dc, :partials_b , partial_vals (val))
310+ setfield! (dc, :rhs_cache_valid , false ) # Invalidate cache
294311 elseif sym === :u
295312 setfield! (dc, :dual_u , val)
296313 setfield! (dc, :partials_u , partial_vals (val))
0 commit comments