Skip to content

Commit 9b69358

Browse files
committed
make sure u0 is correct type
1 parent e5761c8 commit 9b69358

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,12 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
5151
# Solves Dual partials separately
5252
∂_A = cache.partials_A
5353
∂_b = cache.partials_b
54-
dual_u0 = only(partials_to_list(cache.dual_u0))
54+
dual_u0 = !isnothing(cache.dual_u0) ? only(partials_to_list(cache.dual_u0)) : cache.linear_cache.u
5555

5656
rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b)
5757

5858
partial_cache = cache.linear_cache
59-
partial_cache.u0 = dual_u0
59+
partial_cache.u = dual_u0
6060
for i in eachindex(rhs_list)
6161
partial_cache.b = rhs_list[i]
6262
rhs_list[i] = copy(solve!(partial_cache, alg).u)
@@ -142,7 +142,8 @@ function SciMLBase.init(
142142
∂_b = partial_vals(b)
143143
dual_u0 = partial_vals(u0)
144144

145-
newprob = remake(prob; A = new_A, b = new_b, u0 = new_u0)
145+
newprob = LinearProblem(new_A, new_b, u0 = new_u0)
146+
#remake(prob; A = new_A, b = new_b, u0 = new_u0)
146147

147148
non_partial_cache = init(
148149
newprob, alg, args...; alias = alias, abstol = abstol, reltol = reltol,

0 commit comments

Comments
 (0)