@@ -37,6 +37,7 @@ LinearSolve.@concrete mutable struct DualLinearCache
37
37
linear_cache
38
38
prob
39
39
alg
40
+ dual_u0
40
41
partials_A
41
42
partials_b
42
43
end
@@ -48,12 +49,13 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
48
49
# Solves Dual partials separately
49
50
∂_A = cache. partials_A
50
51
∂_b = cache. partials_b
52
+ dual_u0 = only (partials_to_list (cache. dual_u0))
51
53
52
54
rhs_list = xp_linsolve_rhs (uu, ∂_A, ∂_b)
53
55
54
56
new_A = nodual_value (cache. A)
55
57
partial_prob = LinearProblem (new_A, rhs_list[1 ])
56
- partial_cache = init (partial_prob, alg, args... ; kwargs... )
58
+ partial_cache = init (partial_prob, alg, args... ; u0 = dual_u0, kwargs... )
57
59
58
60
for i in eachindex (rhs_list)
59
61
partial_cache. b = rhs_list[i]
@@ -130,20 +132,23 @@ function SciMLBase.init(
130
132
sensealg = LinearSolveAdjoint (),
131
133
kwargs... )
132
134
133
- new_A = nodual_value (prob. A)
134
- new_b = nodual_value (prob. b)
135
+ (; A, b, u0, p) = prob
135
136
136
- ∂_A = partial_vals (prob. A)
137
- ∂_b = partial_vals (prob. b)
137
+ new_A = nodual_value (A)
138
+ new_b = nodual_value (b)
139
+ new_u0 = nodual_value (u0)
140
+
141
+ ∂_A = partial_vals (A)
142
+ ∂_b = partial_vals (b)
143
+ dual_u0 = partial_vals (u0)
138
144
139
145
newprob = remake (prob; A = new_A, b = new_b)
140
146
141
147
non_partial_cache = init (
142
148
newprob, alg, args... ; alias = alias, abstol = abstol, reltol = reltol,
143
149
maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
144
- sensealg = sensealg, kwargs... )
145
-
146
- return DualLinearCache (non_partial_cache, prob, alg, ∂_A, ∂_b)
150
+ sensealg = sensealg, u0 = new_u0, kwargs... )
151
+ return DualLinearCache (non_partial_cache, prob, alg, dual_u0, ∂_A, ∂_b)
147
152
end
148
153
149
154
function SciMLBase. solve! (cache:: DualLinearCache , args... ; kwargs... )
0 commit comments