Skip to content

Commit 51ce056

Browse files
committed
make sure that linearcache.b is reset after dual solve
1 parent f55639a commit 51ce056

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ end
4444

4545
function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...)
4646
sol = solve!(cache.linear_cache, alg, args...; kwargs...)
47+
primal_b = copy(cache.linear_cache.b)
4748
uu = sol.u
4849

4950
primal_sol = deepcopy(sol)
@@ -57,11 +58,15 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
5758

5859
partial_cache = cache.linear_cache
5960
partial_cache.u = dual_u0
61+
6062
for i in eachindex(rhs_list)
6163
partial_cache.b = rhs_list[i]
62-
rhs_list[i] = copy(solve!(partial_cache, alg).u)
64+
rhs_list[i] = copy(solve!(partial_cache, alg, args...; kwargs...).u)
6365
end
6466

67+
# Reset to the original `b`, users will expect that `b` doesn't change if they don't tell it to
68+
partial_cache.b = primal_b
69+
6570
partial_sols = rhs_list
6671

6772
primal_sol, partial_sols
@@ -173,7 +178,7 @@ end
173178
# If setting A or b for DualLinearCache, also set it for the underlying LinearCache
174179
function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val)
175180
# If the property is A or b, also update it in the LinearCache
176-
if sym === :A || sym === :b
181+
if sym === :A || sym === :b || sym === :u
177182
if hasproperty(dc, :linear_cache)
178183
setproperty!(dc.linear_cache, sym, nodual_value(val))
179184
end

0 commit comments

Comments
 (0)