@@ -32,8 +32,16 @@ const DualBLinearProblem = LinearProblem{
32
32
const DualAbstractLinearProblem = Union{
33
33
DualLinearProblem, DualALinearProblem, DualBLinearProblem}
34
34
35
+ LinearSolve. @concrete mutable struct DualLinearCache
36
+ cache
37
+ prob
38
+ alg
39
+ partials_A
40
+ partials_b
41
+ end
42
+
35
43
function linearsolve_forwarddiff_solve (cache:: DualLinearCache , alg, args... ; kwargs... )
36
- sol = solve! (cache, alg, args... ; kwargs... )
44
+ sol = solve! (cache. cache , alg, args... ; kwargs... )
37
45
uu = sol. u
38
46
39
47
# Solves Dual partials separately
@@ -42,11 +50,16 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
42
50
43
51
rhs_list = xp_linsolve_rhs (uu, ∂_A, ∂_b)
44
52
45
- partial_sols = map (rhs_list) do rhs
46
- partial_prob = remake (partial_prob, b = rhs)
47
- solve (partial_prob, alg, args... ; kwargs... ). u
53
+ partial_prob = LinearProblem (cache. cache. A, rhs_list[1 ])
54
+ partial_cache = init (partial_prob, alg, args... ; kwargs... )
55
+
56
+ for i in eachindex (rhs_list)
57
+ partial_cache. b = rhs_list[i]
58
+ rhs_list[i] = copy (solve! (partial_cache, alg). u)
48
59
end
49
60
61
+ partial_sols = rhs_list
62
+
50
63
sol, partial_sols
51
64
end
52
65
@@ -135,19 +148,19 @@ function partials_to_list(partial_matrix)
135
148
return res_list
136
149
end
137
150
138
- function SciMLBase. init (prob :: DualAbstractLinearProblem , alg :: SciMLLinearSolveAlgorithm ,
139
- args ... ;
140
- alias = LinearAliasSpecifier (),
141
- abstol = default_tol ( real ( eltype (prob . b)) ),
142
- reltol = default_tol (real (eltype (prob. b))),
143
- maxiters :: Int = length ( prob. b),
144
- verbose :: Bool = false ,
145
- Pl = nothing ,
146
- Pr = nothing ,
147
- assumptions = OperatorAssumptions ( issquare (prob . A)) ,
148
- sensealg = LinearSolveAdjoint ( ),
149
- kwargs ... )
150
-
151
+ function SciMLBase. init (
152
+ prob :: DualAbstractLinearProblem , alg :: LinearSolve.SciMLLinearSolveAlgorithm ,
153
+ args ... ;
154
+ alias = LinearAliasSpecifier ( ),
155
+ abstol = LinearSolve . default_tol (real (eltype (prob. b))),
156
+ reltol = LinearSolve . default_tol ( real ( eltype ( prob. b)) ),
157
+ maxiters :: Int = length (prob . b) ,
158
+ verbose :: Bool = false ,
159
+ Pl = nothing ,
160
+ Pr = nothing ,
161
+ assumptions = OperatorAssumptions ( issquare (prob . A) ),
162
+ sensealg = LinearSolveAdjoint (),
163
+ kwargs ... )
151
164
new_A = nodual_value (prob. A)
152
165
new_b = nodual_value (prob. b)
153
166
@@ -156,35 +169,28 @@ function SciMLBase.init(prob::DualAbstractLinearProblem, alg::SciMLLinearSolveAl
156
169
157
170
newprob = remake (prob; A = new_A, b = new_b)
158
171
159
- non_partial_cache = init (newprob, alg, args... ; alias = alias, abstol = abstol, reltol = reltol,
172
+ non_partial_cache = init (
173
+ newprob, alg, args... ; alias = alias, abstol = abstol, reltol = reltol,
160
174
maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
161
175
sensealg = sensealg, kwargs... )
162
176
163
177
return DualLinearCache (non_partial_cache, prob, alg, ∂_A, ∂_b)
164
178
end
165
179
166
- mutable struct DualLinearCache
167
- cache
168
- prob
169
- alg
170
- partials_A
171
- partials_b
172
- end
173
-
174
180
function SciMLBase. solve! (cache:: DualLinearCache , args... ; kwargs... )
175
-
176
- sol, partials = linearsolve_forwarddiff_solve ( cache:: DualLinearCache , alg, args... ; kwargs... )
181
+ sol, partials = linearsolve_forwarddiff_solve (
182
+ cache:: DualLinearCache , cache . alg, args... ; kwargs... )
177
183
178
184
if get_dual_type (cache. prob. A) != = nothing
179
- dual_type = get_dual_type (prob. A)
185
+ dual_type = get_dual_type (cache . prob. A)
180
186
elseif get_dual_type (cache. prob. b) != = nothing
181
- dual_type = get_dual_type (prob. b)
187
+ dual_type = get_dual_type (cache . prob. b)
182
188
end
183
189
184
190
dual_sol = linearsolve_dual_solution (sol. u, partials, dual_type)
185
191
186
192
return SciMLBase. build_linear_solution (
187
- alg, dual_sol, sol. resid, sol. cache; sol. retcode, sol. iters, sol. stats
193
+ cache . alg, dual_sol, sol. resid, sol. cache; sol. retcode, sol. iters, sol. stats
188
194
)
189
195
end
190
196
0 commit comments