@@ -35,8 +35,7 @@ const DualAbstractLinearProblem = Union{
35
35
36
36
LinearSolve. @concrete mutable struct DualLinearCache
37
37
linear_cache
38
- prob
39
- alg
38
+ dual_type
40
39
dual_u0
41
40
partials_A
42
41
partials_b
@@ -147,41 +146,42 @@ function SciMLBase.init(
147
146
∂_b = partial_vals (b)
148
147
dual_u0 = partial_vals (u0)
149
148
150
- newprob = LinearProblem (new_A, new_b, u0 = new_u0)
149
+ primal_prob = LinearProblem (new_A, new_b, u0 = new_u0)
151
150
# remake(prob; A = new_A, b = new_b, u0 = new_u0)
152
151
152
+ if get_dual_type (prob. A) != = nothing
153
+ dual_type = get_dual_type (prob. A)
154
+ elseif get_dual_type (prob. b) != = nothing
155
+ dual_type = get_dual_type (prob. b)
156
+ end
157
+
153
158
non_partial_cache = init (
154
- newprob , alg, args... ; alias = alias, abstol = abstol, reltol = reltol,
159
+ primal_prob , alg, args... ; alias = alias, abstol = abstol, reltol = reltol,
155
160
maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
156
161
sensealg = sensealg, u0 = new_u0, kwargs... )
157
- return DualLinearCache (non_partial_cache, prob, alg , dual_u0, ∂_A, ∂_b)
162
+ return DualLinearCache (non_partial_cache, dual_type , dual_u0, ∂_A, ∂_b)
158
163
end
159
164
160
165
function SciMLBase. solve! (cache:: DualLinearCache , args... ; kwargs... )
161
166
sol,
162
167
partials = linearsolve_forwarddiff_solve (
163
168
cache:: DualLinearCache , cache. alg, args... ; kwargs... )
164
169
165
- if get_dual_type (cache. prob. A) != = nothing
166
- dual_type = get_dual_type (cache. prob. A)
167
- elseif get_dual_type (cache. prob. b) != = nothing
168
- dual_type = get_dual_type (cache. prob. b)
169
- end
170
-
171
- dual_sol = linearsolve_dual_solution (sol. u, partials, dual_type)
170
+ dual_sol = linearsolve_dual_solution (sol. u, partials, cache. dual_type)
172
171
173
172
return SciMLBase. build_linear_solution (
174
173
cache. alg, dual_sol, sol. resid, sol. cache; sol. retcode, sol. iters, sol. stats
175
174
)
176
175
end
177
176
178
- # If setting A or b for DualLinearCache, also set it for the underlying LinearCache
177
+ # If setting A or b for DualLinearCache, put the Dual-stripped versions in the LinearCache
178
+ # Also "forwards" setproperty so that
179
179
function Base. setproperty! (dc:: DualLinearCache , sym:: Symbol , val)
180
180
# If the property is A or b, also update it in the LinearCache
181
181
if sym === :A || sym === :b || sym === :u
182
- if hasproperty (dc, :linear_cache )
183
- setproperty! (dc . linear_cache , sym, nodual_value (val) )
184
- end
182
+ setproperty! (dc. linear_cache, sym, nodual_value (val) )
183
+ elseif hasfield (LinearSolve . LinearCache , sym)
184
+ setproperty! (dc . linear_cache, sym, val)
185
185
end
186
186
187
187
# Update the partials if setting A or b
@@ -194,13 +194,12 @@ function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val)
194
194
end
195
195
end
196
196
197
+ # "Forwards" getproperty to LinearCache if necessary
197
198
function Base. getproperty (dc:: DualLinearCache , sym:: Symbol )
198
- if sym === :A
199
- return dc. linear_cache. A
200
- elseif sym === :b
201
- return dc. linear_cache. b
199
+ if hasfield (LinearSolve. LinearCache, sym)
200
+ return getproperty (dc. linear_cache, sym)
202
201
else
203
- getfield (dc,sym)
202
+ return getfield (dc, sym)
204
203
end
205
204
end
206
205
0 commit comments