Skip to content

Commit e6cda65

Browse files
committed
use inits and caches
1 parent 6352024 commit e6cda65

File tree

1 file changed

+59
-25
lines changed

1 file changed

+59
-25
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 59 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -32,23 +32,18 @@ const DualBLinearProblem = LinearProblem{
3232
const DualAbstractLinearProblem = Union{
3333
DualLinearProblem, DualALinearProblem, DualBLinearProblem}
3434

35-
function linearsolve_forwarddiff_solve(prob::LinearProblem, alg, args...; kwargs...)
36-
new_A = nodual_value(prob.A)
37-
new_b = nodual_value(prob.b)
38-
39-
newprob = remake(prob; A = new_A, b = new_b)
40-
41-
sol = solve(newprob, alg, args...; kwargs...)
35+
function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...)
36+
sol = solve!(cache, alg, args...; kwargs...)
4237
uu = sol.u
4338

4439
# Solves Dual partials separately
45-
∂_A = partial_vals(prob.A)
46-
∂_b = partial_vals(prob.b)
40+
∂_A = cache.partials_A
41+
∂_b = cache.partials_b
4742

4843
rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b)
4944

5045
partial_sols = map(rhs_list) do rhs
51-
partial_prob = remake(newprob, b = rhs)
46+
partial_prob = remake(partial_prob, b = rhs)
5247
solve(partial_prob, alg, args...; kwargs...).u
5348
end
5449

@@ -66,21 +61,7 @@ end
6661

6762
function SciMLBase.solve(prob::DualAbstractLinearProblem,
6863
alg::LinearSolve.SciMLLinearSolveAlgorithm, args...; kwargs...)
69-
sol, partials = linearsolve_forwarddiff_solve(
70-
prob, alg, args...; kwargs...
71-
)
72-
73-
if get_dual_type(prob.A) !== nothing
74-
dual_type = get_dual_type(prob.A)
75-
elseif get_dual_type(prob.b) !== nothing
76-
dual_type = get_dual_type(prob.b)
77-
end
78-
79-
dual_sol = linearsolve_dual_solution(sol.u, partials, dual_type)
80-
81-
return SciMLBase.build_linear_solution(
82-
alg, dual_sol, sol.resid, sol.cache; sol.retcode, sol.iters, sol.stats
83-
)
64+
solve!(init(prob, alg, args...; kwargs...))
8465
end
8566

8667
function linearsolve_dual_solution(
@@ -154,4 +135,57 @@ function partials_to_list(partial_matrix)
154135
return res_list
155136
end
156137

138+
function SciMLBase.init(prob::DualAbstractLinearProblem, alg::SciMLLinearSolveAlgorithm,
139+
args...;
140+
alias = LinearAliasSpecifier(),
141+
abstol = default_tol(real(eltype(prob.b))),
142+
reltol = default_tol(real(eltype(prob.b))),
143+
maxiters::Int = length(prob.b),
144+
verbose::Bool = false,
145+
Pl = nothing,
146+
Pr = nothing,
147+
assumptions = OperatorAssumptions(issquare(prob.A)),
148+
sensealg = LinearSolveAdjoint(),
149+
kwargs...)
150+
151+
new_A = nodual_value(prob.A)
152+
new_b = nodual_value(prob.b)
153+
154+
∂_A = partial_vals(prob.A)
155+
∂_b = partial_vals(prob.b)
156+
157+
newprob = remake(prob; A = new_A, b = new_b)
158+
159+
non_partial_cache = init(newprob, alg, args...; alias = alias, abstol = abstol, reltol = reltol,
160+
maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
161+
sensealg = sensealg, kwargs...)
162+
163+
return DualLinearCache(non_partial_cache, prob, alg, ∂_A, ∂_b)
164+
end
165+
166+
mutable struct DualLinearCache
167+
cache
168+
prob
169+
alg
170+
partials_A
171+
partials_b
172+
end
173+
174+
function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...)
175+
176+
sol, partials = linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...)
177+
178+
if get_dual_type(cache.prob.A) !== nothing
179+
dual_type = get_dual_type(prob.A)
180+
elseif get_dual_type(cache.prob.b) !== nothing
181+
dual_type = get_dual_type(prob.b)
182+
end
183+
184+
dual_sol = linearsolve_dual_solution(sol.u, partials, dual_type)
185+
186+
return SciMLBase.build_linear_solution(
187+
alg, dual_sol, sol.resid, sol.cache; sol.retcode, sol.iters, sol.stats
188+
)
189+
end
190+
157191
end

0 commit comments

Comments
 (0)