diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 01dabe7a2..fb0ec730f 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -34,9 +34,8 @@ const DualBLinearProblem = LinearProblem{ const DualAbstractLinearProblem = Union{ DualLinearProblem, DualALinearProblem, DualBLinearProblem} -LinearSolve.@concrete mutable struct DualLinearCache +LinearSolve.@concrete mutable struct DualLinearCache{DT <: Dual} linear_cache - dual_type partials_A partials_b @@ -54,7 +53,13 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa primal_b = copy(cache.linear_cache.b) uu = sol.u - primal_sol = deepcopy(sol) + primal_sol = (; + u = recursivecopy(sol.u), + resid = recursivecopy(sol.resid), + retcode = recursivecopy(sol.retcode), + iters = recursivecopy(sol.iters), + stats = recursivecopy(sol.stats) + ) # Solves Dual partials separately ∂_A = cache.partials_A @@ -103,21 +108,15 @@ function xp_linsolve_rhs( end function linearsolve_dual_solution( - u::Number, partials, dual_type) - return dual_type(u, partials) -end - -function linearsolve_dual_solution(u::Number, partials, - dual_type::Type{<:Dual{T, V, P}}) where {T, V, P} - # Handle single-level duals - return dual_type(u, partials) + u::Number, partials, cache::DualLinearCache{DT}) where {DT} + return DT(u, partials) end function linearsolve_dual_solution(u::AbstractArray, partials, - dual_type::Type{<:Dual{T, V, P}}) where {T, V, P} + cache::DualLinearCache{DT}) where {DT} # Handle single-level duals for arrays partials_list = RecursiveArrayTools.VectorOfArray(partials) - return map(((uᵢ, pᵢ),) -> dual_type(uᵢ, Partials(Tuple(pᵢ))), + return map(((uᵢ, pᵢ),) -> DT(uᵢ, Partials(Tuple(pᵢ))), zip(u, partials_list[i, :] for i in 1:length(partials_list.u[1]))) end @@ -167,7 +166,7 @@ function __dual_init( alias = alias, abstol = abstol, reltol = reltol, maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions, sensealg = sensealg, u0 = new_u0, kwargs...) - return DualLinearCache(non_partial_cache, dual_type, ∂_A, ∂_b, + return DualLinearCache{dual_type}(non_partial_cache, ∂_A, ∂_b, !isnothing(∂_b) ? zero.(∂_b) : ∂_b, A, b, zeros(dual_type, length(b))) end @@ -176,11 +175,11 @@ function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...) end function SciMLBase.solve!( - cache::DualLinearCache, alg::SciMLLinearSolveAlgorithm, args...; kwargs...) + cache::DualLinearCache{DT}, alg::SciMLLinearSolveAlgorithm, args...; kwargs...) where {DT <: ForwardDiff.Dual} sol, partials = linearsolve_forwarddiff_solve( cache::DualLinearCache, cache.alg, args...; kwargs...) - dual_sol = linearsolve_dual_solution(sol.u, partials, cache.dual_type) + dual_sol = linearsolve_dual_solution(sol.u, partials, cache) if cache.dual_u isa AbstractArray cache.dual_u[:] = dual_sol