Skip to content

Commit 313e286

Browse files
committed
make sure using nonmutated A
1 parent 277c4f8 commit 313e286

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,10 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
5151

5252
rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b)
5353

54-
partial_prob = LinearProblem(cache.cache.A, rhs_list[1])
54+
new_A = nodual_value(cache.prob.A)
55+
partial_prob = LinearProblem(new_A, rhs_list[1])
5556
partial_cache = init(partial_prob, alg, args...; kwargs...)
5657

57-
Main.@infiltrate
58-
5958
for i in eachindex(rhs_list)
6059
partial_cache.b = rhs_list[i]
6160
rhs_list[i] = copy(solve!(partial_cache, alg).u)
@@ -110,7 +109,8 @@ function xp_linsolve_rhs(uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials
110109
b_list = partials_to_list(∂_b)
111110

112111
Auu = [A * uu for A in A_list]
113-
b_list .- Auu
112+
113+
return b_list .- Auu
114114
end
115115

116116
function xp_linsolve_rhs(
@@ -119,13 +119,12 @@ function xp_linsolve_rhs(
119119

120120
Auu = [A * uu for A in A_list]
121121

122-
-Auu
122+
return -Auu
123123
end
124124

125125
function xp_linsolve_rhs(
126126
uu, ∂_A::Nothing, ∂_b::Union{<:Partials, <:AbstractArray{<:Partials}})
127127
b_list = partials_to_list(∂_b)
128-
Main.@infiltrate
129128
b_list
130129
end
131130

0 commit comments

Comments
 (0)