@@ -32,23 +32,18 @@ const DualBLinearProblem = LinearProblem{
3232const DualAbstractLinearProblem = Union{
3333 DualLinearProblem, DualALinearProblem, DualBLinearProblem}
3434
35- function linearsolve_forwarddiff_solve (prob:: LinearProblem , alg, args... ; kwargs... )
36- new_A = nodual_value (prob. A)
37- new_b = nodual_value (prob. b)
38-
39- newprob = remake (prob; A = new_A, b = new_b)
40-
41- sol = solve (newprob, alg, args... ; kwargs... )
35+ function linearsolve_forwarddiff_solve (cache:: DualLinearCache , alg, args... ; kwargs... )
36+ sol = solve! (cache, alg, args... ; kwargs... )
4237 uu = sol. u
4338
4439 # Solves Dual partials separately
45- ∂_A = partial_vals (prob . A)
46- ∂_b = partial_vals (prob . b)
40+ ∂_A = cache . partials_A
41+ ∂_b = cache . partials_b
4742
4843 rhs_list = xp_linsolve_rhs (uu, ∂_A, ∂_b)
4944
5045 partial_sols = map (rhs_list) do rhs
51- partial_prob = remake (newprob , b = rhs)
46+ partial_prob = remake (partial_prob , b = rhs)
5247 solve (partial_prob, alg, args... ; kwargs... ). u
5348 end
5449
6661
6762function SciMLBase. solve (prob:: DualAbstractLinearProblem ,
6863 alg:: LinearSolve.SciMLLinearSolveAlgorithm , args... ; kwargs... )
69- sol, partials = linearsolve_forwarddiff_solve (
70- prob, alg, args... ; kwargs...
71- )
72-
73- if get_dual_type (prob. A) != = nothing
74- dual_type = get_dual_type (prob. A)
75- elseif get_dual_type (prob. b) != = nothing
76- dual_type = get_dual_type (prob. b)
77- end
78-
79- dual_sol = linearsolve_dual_solution (sol. u, partials, dual_type)
80-
81- return SciMLBase. build_linear_solution (
82- alg, dual_sol, sol. resid, sol. cache; sol. retcode, sol. iters, sol. stats
83- )
64+ solve! (init (prob, alg, args... ; kwargs... ))
8465end
8566
8667function linearsolve_dual_solution (
@@ -154,4 +135,57 @@ function partials_to_list(partial_matrix)
154135 return res_list
155136end
156137
138+ function SciMLBase. init (prob:: DualAbstractLinearProblem , alg:: SciMLLinearSolveAlgorithm ,
139+ args... ;
140+ alias = LinearAliasSpecifier (),
141+ abstol = default_tol (real (eltype (prob. b))),
142+ reltol = default_tol (real (eltype (prob. b))),
143+ maxiters:: Int = length (prob. b),
144+ verbose:: Bool = false ,
145+ Pl = nothing ,
146+ Pr = nothing ,
147+ assumptions = OperatorAssumptions (issquare (prob. A)),
148+ sensealg = LinearSolveAdjoint (),
149+ kwargs... )
150+
151+ new_A = nodual_value (prob. A)
152+ new_b = nodual_value (prob. b)
153+
154+ ∂_A = partial_vals (prob. A)
155+ ∂_b = partial_vals (prob. b)
156+
157+ newprob = remake (prob; A = new_A, b = new_b)
158+
159+ non_partial_cache = init (newprob, alg, args... ; alias = alias, abstol = abstol, reltol = reltol,
160+ maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
161+ sensealg = sensealg, kwargs... )
162+
163+ return DualLinearCache (non_partial_cache, prob, alg, ∂_A, ∂_b)
164+ end
165+
166+ mutable struct DualLinearCache
167+ cache
168+ prob
169+ alg
170+ partials_A
171+ partials_b
172+ end
173+
174+ function SciMLBase. solve! (cache:: DualLinearCache , args... ; kwargs... )
175+
176+ sol, partials = linearsolve_forwarddiff_solve (cache:: DualLinearCache , alg, args... ; kwargs... )
177+
178+ if get_dual_type (cache. prob. A) != = nothing
179+ dual_type = get_dual_type (prob. A)
180+ elseif get_dual_type (cache. prob. b) != = nothing
181+ dual_type = get_dual_type (prob. b)
182+ end
183+
184+ dual_sol = linearsolve_dual_solution (sol. u, partials, dual_type)
185+
186+ return SciMLBase. build_linear_solution (
187+ alg, dual_sol, sol. resid, sol. cache; sol. retcode, sol. iters, sol. stats
188+ )
189+ end
190+
157191end
0 commit comments