Skip to content

Commit b39ce87

Browse files
committed
reuse primal cache for Dual computation
1 parent 1b48666 commit b39ce87

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
4646
sol = solve!(cache.linear_cache, alg, args...; kwargs...)
4747
uu = sol.u
4848

49+
primal_sol = deepcopy(sol)
50+
4951
# Solves Dual partials separately
5052
∂_A = cache.partials_A
5153
∂_b = cache.partials_b
@@ -54,17 +56,16 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
5456
rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b)
5557

5658
new_A = nodual_value(cache.A)
57-
partial_prob = LinearProblem(new_A, rhs_list[1])
58-
partial_cache = init(partial_prob, alg, args...; u0 = dual_u0, kwargs...)
59-
59+
partial_cache = cache.linear_cache
60+
partial_cache.u0 = dual_u0
6061
for i in eachindex(rhs_list)
6162
partial_cache.b = rhs_list[i]
6263
rhs_list[i] = copy(solve!(partial_cache, alg).u)
6364
end
6465

6566
partial_sols = rhs_list
6667

67-
sol, partial_sols
68+
primal_sol, partial_sols
6869
end
6970

7071
function xp_linsolve_rhs(uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}},

0 commit comments

Comments
 (0)