@@ -35,8 +35,7 @@ const DualAbstractLinearProblem = Union{
3535
3636LinearSolve. @concrete mutable struct DualLinearCache
3737 linear_cache
38- prob
39- alg
38+ dual_type
4039 dual_u0
4140 partials_A
4241 partials_b
@@ -147,41 +146,42 @@ function SciMLBase.init(
147146 ∂_b = partial_vals (b)
148147 dual_u0 = partial_vals (u0)
149148
150- newprob = LinearProblem (new_A, new_b, u0 = new_u0)
149+ primal_prob = LinearProblem (new_A, new_b, u0 = new_u0)
151150 # remake(prob; A = new_A, b = new_b, u0 = new_u0)
152151
152+ if get_dual_type (prob. A) != = nothing
153+ dual_type = get_dual_type (prob. A)
154+ elseif get_dual_type (prob. b) != = nothing
155+ dual_type = get_dual_type (prob. b)
156+ end
157+
153158 non_partial_cache = init (
154- newprob , alg, args... ; alias = alias, abstol = abstol, reltol = reltol,
159+ primal_prob , alg, args... ; alias = alias, abstol = abstol, reltol = reltol,
155160 maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
156161 sensealg = sensealg, u0 = new_u0, kwargs... )
157- return DualLinearCache (non_partial_cache, prob, alg , dual_u0, ∂_A, ∂_b)
162+ return DualLinearCache (non_partial_cache, dual_type , dual_u0, ∂_A, ∂_b)
158163end
159164
160165function SciMLBase. solve! (cache:: DualLinearCache , args... ; kwargs... )
161166 sol,
162167 partials = linearsolve_forwarddiff_solve (
163168 cache:: DualLinearCache , cache. alg, args... ; kwargs... )
164169
165- if get_dual_type (cache. prob. A) != = nothing
166- dual_type = get_dual_type (cache. prob. A)
167- elseif get_dual_type (cache. prob. b) != = nothing
168- dual_type = get_dual_type (cache. prob. b)
169- end
170-
171- dual_sol = linearsolve_dual_solution (sol. u, partials, dual_type)
170+ dual_sol = linearsolve_dual_solution (sol. u, partials, cache. dual_type)
172171
173172 return SciMLBase. build_linear_solution (
174173 cache. alg, dual_sol, sol. resid, sol. cache; sol. retcode, sol. iters, sol. stats
175174 )
176175end
177176
178- # If setting A or b for DualLinearCache, also set it for the underlying LinearCache
177+ # If setting A or b for DualLinearCache, put the Dual-stripped versions in the LinearCache
178+ # Also "forwards" setproperty so that
179179function Base. setproperty! (dc:: DualLinearCache , sym:: Symbol , val)
180180 # If the property is A or b, also update it in the LinearCache
181181 if sym === :A || sym === :b || sym === :u
182- if hasproperty (dc, :linear_cache )
183- setproperty! (dc . linear_cache , sym, nodual_value (val) )
184- end
182+ setproperty! (dc. linear_cache, sym, nodual_value (val) )
183+ elseif hasfield (LinearSolve . LinearCache , sym)
184+ setproperty! (dc . linear_cache, sym, val)
185185 end
186186
187187 # Update the partials if setting A or b
@@ -194,13 +194,12 @@ function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val)
194194 end
195195end
196196
197+ # "Forwards" getproperty to LinearCache if necessary
197198function Base. getproperty (dc:: DualLinearCache , sym:: Symbol )
198- if sym === :A
199- return dc. linear_cache. A
200- elseif sym === :b
201- return dc. linear_cache. b
199+ if hasfield (LinearSolve. LinearCache, sym)
200+ return getproperty (dc. linear_cache, sym)
202201 else
203- getfield (dc,sym)
202+ return getfield (dc, sym)
204203 end
205204end
206205
0 commit comments