@@ -32,8 +32,16 @@ const DualBLinearProblem = LinearProblem{
3232const DualAbstractLinearProblem = Union{
3333 DualLinearProblem, DualALinearProblem, DualBLinearProblem}
3434
35+ LinearSolve. @concrete mutable struct DualLinearCache
36+ cache
37+ prob
38+ alg
39+ partials_A
40+ partials_b
41+ end
42+
3543function linearsolve_forwarddiff_solve (cache:: DualLinearCache , alg, args... ; kwargs... )
36- sol = solve! (cache, alg, args... ; kwargs... )
44+ sol = solve! (cache. cache , alg, args... ; kwargs... )
3745 uu = sol. u
3846
3947 # Solves Dual partials separately
@@ -42,11 +50,16 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
4250
4351 rhs_list = xp_linsolve_rhs (uu, ∂_A, ∂_b)
4452
45- partial_sols = map (rhs_list) do rhs
46- partial_prob = remake (partial_prob, b = rhs)
47- solve (partial_prob, alg, args... ; kwargs... ). u
53+ partial_prob = LinearProblem (cache. cache. A, rhs_list[1 ])
54+ partial_cache = init (partial_prob, alg, args... ; kwargs... )
55+
56+ for i in eachindex (rhs_list)
57+ partial_cache. b = rhs_list[i]
58+ rhs_list[i] = copy (solve! (partial_cache, alg). u)
4859 end
4960
61+ partial_sols = rhs_list
62+
5063 sol, partial_sols
5164end
5265
@@ -135,19 +148,19 @@ function partials_to_list(partial_matrix)
135148 return res_list
136149end
137150
138- function SciMLBase. init (prob :: DualAbstractLinearProblem , alg :: SciMLLinearSolveAlgorithm ,
139- args ... ;
140- alias = LinearAliasSpecifier (),
141- abstol = default_tol ( real ( eltype (prob . b)) ),
142- reltol = default_tol (real (eltype (prob. b))),
143- maxiters :: Int = length ( prob. b),
144- verbose :: Bool = false ,
145- Pl = nothing ,
146- Pr = nothing ,
147- assumptions = OperatorAssumptions ( issquare (prob . A)) ,
148- sensealg = LinearSolveAdjoint ( ),
149- kwargs ... )
150-
151+ function SciMLBase. init (
152+ prob :: DualAbstractLinearProblem , alg :: LinearSolve.SciMLLinearSolveAlgorithm ,
153+ args ... ;
154+ alias = LinearAliasSpecifier ( ),
155+ abstol = LinearSolve . default_tol (real (eltype (prob. b))),
156+ reltol = LinearSolve . default_tol ( real ( eltype ( prob. b)) ),
157+ maxiters :: Int = length (prob . b) ,
158+ verbose :: Bool = false ,
159+ Pl = nothing ,
160+ Pr = nothing ,
161+ assumptions = OperatorAssumptions ( issquare (prob . A) ),
162+ sensealg = LinearSolveAdjoint (),
163+ kwargs ... )
151164 new_A = nodual_value (prob. A)
152165 new_b = nodual_value (prob. b)
153166
@@ -156,35 +169,28 @@ function SciMLBase.init(prob::DualAbstractLinearProblem, alg::SciMLLinearSolveAl
156169
157170 newprob = remake (prob; A = new_A, b = new_b)
158171
159- non_partial_cache = init (newprob, alg, args... ; alias = alias, abstol = abstol, reltol = reltol,
172+ non_partial_cache = init (
173+ newprob, alg, args... ; alias = alias, abstol = abstol, reltol = reltol,
160174 maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
161175 sensealg = sensealg, kwargs... )
162176
163177 return DualLinearCache (non_partial_cache, prob, alg, ∂_A, ∂_b)
164178end
165179
166- mutable struct DualLinearCache
167- cache
168- prob
169- alg
170- partials_A
171- partials_b
172- end
173-
174180function SciMLBase. solve! (cache:: DualLinearCache , args... ; kwargs... )
175-
176- sol, partials = linearsolve_forwarddiff_solve ( cache:: DualLinearCache , alg, args... ; kwargs... )
181+ sol, partials = linearsolve_forwarddiff_solve (
182+ cache:: DualLinearCache , cache . alg, args... ; kwargs... )
177183
178184 if get_dual_type (cache. prob. A) != = nothing
179- dual_type = get_dual_type (prob. A)
185+ dual_type = get_dual_type (cache . prob. A)
180186 elseif get_dual_type (cache. prob. b) != = nothing
181- dual_type = get_dual_type (prob. b)
187+ dual_type = get_dual_type (cache . prob. b)
182188 end
183189
184190 dual_sol = linearsolve_dual_solution (sol. u, partials, dual_type)
185191
186192 return SciMLBase. build_linear_solution (
187- alg, dual_sol, sol. resid, sol. cache; sol. retcode, sol. iters, sol. stats
193+ cache . alg, dual_sol, sol. resid, sol. cache; sol. retcode, sol. iters, sol. stats
188194 )
189195end
190196
0 commit comments