@@ -37,6 +37,7 @@ LinearSolve.@concrete mutable struct DualLinearCache
3737 linear_cache
3838 prob
3939 alg
40+ dual_u0
4041 partials_A
4142 partials_b
4243end
@@ -48,12 +49,13 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
4849 # Solves Dual partials separately
4950 ∂_A = cache. partials_A
5051 ∂_b = cache. partials_b
52+ dual_u0 = only (partials_to_list (cache. dual_u0))
5153
5254 rhs_list = xp_linsolve_rhs (uu, ∂_A, ∂_b)
5355
5456 new_A = nodual_value (cache. A)
5557 partial_prob = LinearProblem (new_A, rhs_list[1 ])
56- partial_cache = init (partial_prob, alg, args... ; kwargs... )
58+ partial_cache = init (partial_prob, alg, args... ; u0 = dual_u0, kwargs... )
5759
5860 for i in eachindex (rhs_list)
5961 partial_cache. b = rhs_list[i]
@@ -130,20 +132,23 @@ function SciMLBase.init(
130132 sensealg = LinearSolveAdjoint (),
131133 kwargs... )
132134
133- new_A = nodual_value (prob. A)
134- new_b = nodual_value (prob. b)
135+ (; A, b, u0, p) = prob
135136
136- ∂_A = partial_vals (prob. A)
137- ∂_b = partial_vals (prob. b)
137+ new_A = nodual_value (A)
138+ new_b = nodual_value (b)
139+ new_u0 = nodual_value (u0)
140+
141+ ∂_A = partial_vals (A)
142+ ∂_b = partial_vals (b)
143+ dual_u0 = partial_vals (u0)
138144
139145 newprob = remake (prob; A = new_A, b = new_b)
140146
141147 non_partial_cache = init (
142148 newprob, alg, args... ; alias = alias, abstol = abstol, reltol = reltol,
143149 maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
144- sensealg = sensealg, kwargs... )
145-
146- return DualLinearCache (non_partial_cache, prob, alg, ∂_A, ∂_b)
150+ sensealg = sensealg, u0 = new_u0, kwargs... )
151+ return DualLinearCache (non_partial_cache, prob, alg, dual_u0, ∂_A, ∂_b)
147152end
148153
149154function SciMLBase. solve! (cache:: DualLinearCache , args... ; kwargs... )
0 commit comments