@@ -32,23 +32,18 @@ const DualBLinearProblem = LinearProblem{
32
32
const DualAbstractLinearProblem = Union{
33
33
DualLinearProblem, DualALinearProblem, DualBLinearProblem}
34
34
35
- function linearsolve_forwarddiff_solve (prob:: LinearProblem , alg, args... ; kwargs... )
36
- new_A = nodual_value (prob. A)
37
- new_b = nodual_value (prob. b)
38
-
39
- newprob = remake (prob; A = new_A, b = new_b)
40
-
41
- sol = solve (newprob, alg, args... ; kwargs... )
35
+ function linearsolve_forwarddiff_solve (cache:: DualLinearCache , alg, args... ; kwargs... )
36
+ sol = solve! (cache, alg, args... ; kwargs... )
42
37
uu = sol. u
43
38
44
39
# Solves Dual partials separately
45
- ∂_A = partial_vals (prob . A)
46
- ∂_b = partial_vals (prob . b)
40
+ ∂_A = cache . partials_A
41
+ ∂_b = cache . partials_b
47
42
48
43
rhs_list = xp_linsolve_rhs (uu, ∂_A, ∂_b)
49
44
50
45
partial_sols = map (rhs_list) do rhs
51
- partial_prob = remake (newprob , b = rhs)
46
+ partial_prob = remake (partial_prob , b = rhs)
52
47
solve (partial_prob, alg, args... ; kwargs... ). u
53
48
end
54
49
66
61
67
62
function SciMLBase. solve (prob:: DualAbstractLinearProblem ,
68
63
alg:: LinearSolve.SciMLLinearSolveAlgorithm , args... ; kwargs... )
69
- sol, partials = linearsolve_forwarddiff_solve (
70
- prob, alg, args... ; kwargs...
71
- )
72
-
73
- if get_dual_type (prob. A) != = nothing
74
- dual_type = get_dual_type (prob. A)
75
- elseif get_dual_type (prob. b) != = nothing
76
- dual_type = get_dual_type (prob. b)
77
- end
78
-
79
- dual_sol = linearsolve_dual_solution (sol. u, partials, dual_type)
80
-
81
- return SciMLBase. build_linear_solution (
82
- alg, dual_sol, sol. resid, sol. cache; sol. retcode, sol. iters, sol. stats
83
- )
64
+ solve! (init (prob, alg, args... ; kwargs... ))
84
65
end
85
66
86
67
function linearsolve_dual_solution (
@@ -154,4 +135,57 @@ function partials_to_list(partial_matrix)
154
135
return res_list
155
136
end
156
137
138
+ function SciMLBase. init (prob:: DualAbstractLinearProblem , alg:: SciMLLinearSolveAlgorithm ,
139
+ args... ;
140
+ alias = LinearAliasSpecifier (),
141
+ abstol = default_tol (real (eltype (prob. b))),
142
+ reltol = default_tol (real (eltype (prob. b))),
143
+ maxiters:: Int = length (prob. b),
144
+ verbose:: Bool = false ,
145
+ Pl = nothing ,
146
+ Pr = nothing ,
147
+ assumptions = OperatorAssumptions (issquare (prob. A)),
148
+ sensealg = LinearSolveAdjoint (),
149
+ kwargs... )
150
+
151
+ new_A = nodual_value (prob. A)
152
+ new_b = nodual_value (prob. b)
153
+
154
+ ∂_A = partial_vals (prob. A)
155
+ ∂_b = partial_vals (prob. b)
156
+
157
+ newprob = remake (prob; A = new_A, b = new_b)
158
+
159
+ non_partial_cache = init (newprob, alg, args... ; alias = alias, abstol = abstol, reltol = reltol,
160
+ maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
161
+ sensealg = sensealg, kwargs... )
162
+
163
+ return DualLinearCache (non_partial_cache, prob, alg, ∂_A, ∂_b)
164
+ end
165
+
166
+ mutable struct DualLinearCache
167
+ cache
168
+ prob
169
+ alg
170
+ partials_A
171
+ partials_b
172
+ end
173
+
174
+ function SciMLBase. solve! (cache:: DualLinearCache , args... ; kwargs... )
175
+
176
+ sol, partials = linearsolve_forwarddiff_solve (cache:: DualLinearCache , alg, args... ; kwargs... )
177
+
178
+ if get_dual_type (cache. prob. A) != = nothing
179
+ dual_type = get_dual_type (prob. A)
180
+ elseif get_dual_type (cache. prob. b) != = nothing
181
+ dual_type = get_dual_type (prob. b)
182
+ end
183
+
184
+ dual_sol = linearsolve_dual_solution (sol. u, partials, dual_type)
185
+
186
+ return SciMLBase. build_linear_solution (
187
+ alg, dual_sol, sol. resid, sol. cache; sol. retcode, sol. iters, sol. stats
188
+ )
189
+ end
190
+
157
191
end
0 commit comments