Skip to content

Commit f9cd2fe

Browse files
committed
enable dual u0
1 parent 680aec6 commit f9cd2fe

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ LinearSolve.@concrete mutable struct DualLinearCache
3737
linear_cache
3838
prob
3939
alg
40+
dual_u0
4041
partials_A
4142
partials_b
4243
end
@@ -48,12 +49,13 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
4849
# Solves Dual partials separately
4950
∂_A = cache.partials_A
5051
∂_b = cache.partials_b
52+
dual_u0 = only(partials_to_list(cache.dual_u0))
5153

5254
rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b)
5355

5456
new_A = nodual_value(cache.A)
5557
partial_prob = LinearProblem(new_A, rhs_list[1])
56-
partial_cache = init(partial_prob, alg, args...; kwargs...)
58+
partial_cache = init(partial_prob, alg, args...; u0 = dual_u0, kwargs...)
5759

5860
for i in eachindex(rhs_list)
5961
partial_cache.b = rhs_list[i]
@@ -130,20 +132,23 @@ function SciMLBase.init(
130132
sensealg = LinearSolveAdjoint(),
131133
kwargs...)
132134

133-
new_A = nodual_value(prob.A)
134-
new_b = nodual_value(prob.b)
135+
(; A, b, u0, p) = prob
135136

136-
∂_A = partial_vals(prob.A)
137-
∂_b = partial_vals(prob.b)
137+
new_A = nodual_value(A)
138+
new_b = nodual_value(b)
139+
new_u0 = nodual_value(u0)
140+
141+
∂_A = partial_vals(A)
142+
∂_b = partial_vals(b)
143+
dual_u0 = partial_vals(u0)
138144

139145
newprob = remake(prob; A = new_A, b = new_b)
140146

141147
non_partial_cache = init(
142148
newprob, alg, args...; alias = alias, abstol = abstol, reltol = reltol,
143149
maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
144-
sensealg = sensealg, kwargs...)
145-
146-
return DualLinearCache(non_partial_cache, prob, alg, ∂_A, ∂_b)
150+
sensealg = sensealg, u0 = new_u0, kwargs...)
151+
return DualLinearCache(non_partial_cache, prob, alg, dual_u0, ∂_A, ∂_b)
147152
end
148153

149154
function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...)

0 commit comments

Comments
 (0)