Skip to content

Commit 501f07d

Browse files
committed
rearrange
1 parent e6cda65 commit 501f07d

File tree

1 file changed

+37
-31
lines changed

1 file changed

+37
-31
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 37 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,16 @@ const DualBLinearProblem = LinearProblem{
3232
const DualAbstractLinearProblem = Union{
3333
DualLinearProblem, DualALinearProblem, DualBLinearProblem}
3434

35+
LinearSolve.@concrete mutable struct DualLinearCache
36+
cache
37+
prob
38+
alg
39+
partials_A
40+
partials_b
41+
end
42+
3543
function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...)
36-
sol = solve!(cache, alg, args...; kwargs...)
44+
sol = solve!(cache.cache, alg, args...; kwargs...)
3745
uu = sol.u
3846

3947
# Solves Dual partials separately
@@ -42,11 +50,16 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
4250

4351
rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b)
4452

45-
partial_sols = map(rhs_list) do rhs
46-
partial_prob = remake(partial_prob, b = rhs)
47-
solve(partial_prob, alg, args...; kwargs...).u
53+
partial_prob = LinearProblem(cache.cache.A, rhs_list[1])
54+
partial_cache = init(partial_prob, alg, args...; kwargs...)
55+
56+
for i in eachindex(rhs_list)
57+
partial_cache.b = rhs_list[i]
58+
rhs_list[i] = copy(solve!(partial_cache, alg).u)
4859
end
4960

61+
partial_sols = rhs_list
62+
5063
sol, partial_sols
5164
end
5265

@@ -135,19 +148,19 @@ function partials_to_list(partial_matrix)
135148
return res_list
136149
end
137150

138-
function SciMLBase.init(prob::DualAbstractLinearProblem, alg::SciMLLinearSolveAlgorithm,
139-
args...;
140-
alias = LinearAliasSpecifier(),
141-
abstol = default_tol(real(eltype(prob.b))),
142-
reltol = default_tol(real(eltype(prob.b))),
143-
maxiters::Int = length(prob.b),
144-
verbose::Bool = false,
145-
Pl = nothing,
146-
Pr = nothing,
147-
assumptions = OperatorAssumptions(issquare(prob.A)),
148-
sensealg = LinearSolveAdjoint(),
149-
kwargs...)
150-
151+
function SciMLBase.init(
152+
prob::DualAbstractLinearProblem, alg::LinearSolve.SciMLLinearSolveAlgorithm,
153+
args...;
154+
alias = LinearAliasSpecifier(),
155+
abstol = LinearSolve.default_tol(real(eltype(prob.b))),
156+
reltol = LinearSolve.default_tol(real(eltype(prob.b))),
157+
maxiters::Int = length(prob.b),
158+
verbose::Bool = false,
159+
Pl = nothing,
160+
Pr = nothing,
161+
assumptions = OperatorAssumptions(issquare(prob.A)),
162+
sensealg = LinearSolveAdjoint(),
163+
kwargs...)
151164
new_A = nodual_value(prob.A)
152165
new_b = nodual_value(prob.b)
153166

@@ -156,35 +169,28 @@ function SciMLBase.init(prob::DualAbstractLinearProblem, alg::SciMLLinearSolveAl
156169

157170
newprob = remake(prob; A = new_A, b = new_b)
158171

159-
non_partial_cache = init(newprob, alg, args...; alias = alias, abstol = abstol, reltol = reltol,
172+
non_partial_cache = init(
173+
newprob, alg, args...; alias = alias, abstol = abstol, reltol = reltol,
160174
maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
161175
sensealg = sensealg, kwargs...)
162176

163177
return DualLinearCache(non_partial_cache, prob, alg, ∂_A, ∂_b)
164178
end
165179

166-
mutable struct DualLinearCache
167-
cache
168-
prob
169-
alg
170-
partials_A
171-
partials_b
172-
end
173-
174180
function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...)
175-
176-
sol, partials = linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...)
181+
sol, partials = linearsolve_forwarddiff_solve(
182+
cache::DualLinearCache, cache.alg, args...; kwargs...)
177183

178184
if get_dual_type(cache.prob.A) !== nothing
179-
dual_type = get_dual_type(prob.A)
185+
dual_type = get_dual_type(cache.prob.A)
180186
elseif get_dual_type(cache.prob.b) !== nothing
181-
dual_type = get_dual_type(prob.b)
187+
dual_type = get_dual_type(cache.prob.b)
182188
end
183189

184190
dual_sol = linearsolve_dual_solution(sol.u, partials, dual_type)
185191

186192
return SciMLBase.build_linear_solution(
187-
alg, dual_sol, sol.resid, sol.cache; sol.retcode, sol.iters, sol.stats
193+
cache.alg, dual_sol, sol.resid, sol.cache; sol.retcode, sol.iters, sol.stats
188194
)
189195
end
190196

0 commit comments

Comments
 (0)