From 0a156193b8e632ca8c2c3da63ae4e106f308e877 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 13 Aug 2025 11:57:52 -0400 Subject: [PATCH 1/3] use recursivecopy instead of deepcopy --- ext/LinearSolveForwardDiffExt.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 01dabe7a2..f2fbfa938 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -54,7 +54,13 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa primal_b = copy(cache.linear_cache.b) uu = sol.u - primal_sol = deepcopy(sol) + primal_sol = (; + u = recursivecopy(sol.u), + resid = recursivecopy(sol.resid), + retcode = recursivecopy(sol.retcode), + iters = recursivecopy(sol.iters), + stats = recursivecopy(sol.stats) + ) # Solves Dual partials separately ∂_A = cache.partials_A From 90c4485c1847469e036746accfdb878fe4f63cc2 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 13 Aug 2025 12:24:33 -0400 Subject: [PATCH 2/3] put Dual types in type system --- ext/LinearSolveForwardDiffExt.jl | 40 +++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index f2fbfa938..3fac6efeb 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -34,9 +34,21 @@ const DualBLinearProblem = LinearProblem{ const DualAbstractLinearProblem = Union{ DualLinearProblem, DualALinearProblem, DualBLinearProblem} -LinearSolve.@concrete mutable struct DualLinearCache +# LinearSolve.@concrete mutable struct DualLinearCache +# linear_cache +# dual_type + +# partials_A +# partials_b +# partials_u + +# dual_A +# dual_b +# dual_u +# end + +LinearSolve.@concrete mutable struct DualLinearCache{DT <: Dual} linear_cache - dual_type partials_A partials_b @@ -109,21 +121,21 @@ function xp_linsolve_rhs( end function linearsolve_dual_solution( - u::Number, partials, dual_type) - return dual_type(u, partials) + u::Number, partials, cache::DualLinearCache{DT}) where {DT} + return DT(u, partials) end -function linearsolve_dual_solution(u::Number, partials, - dual_type::Type{<:Dual{T, V, P}}) where {T, V, P} - # Handle single-level duals - return dual_type(u, partials) -end +# function linearsolve_dual_solution(u::Number, partials, +# dual_type::Type{<:Dual{T, V, P}}) where {T, V, P} +# # Handle single-level duals +# return dual_type(u, partials) +# end function linearsolve_dual_solution(u::AbstractArray, partials, - dual_type::Type{<:Dual{T, V, P}}) where {T, V, P} + cache::DualLinearCache{DT}) where {DT} # Handle single-level duals for arrays partials_list = RecursiveArrayTools.VectorOfArray(partials) - return map(((uᵢ, pᵢ),) -> dual_type(uᵢ, Partials(Tuple(pᵢ))), + return map(((uᵢ, pᵢ),) -> DT(uᵢ, Partials(Tuple(pᵢ))), zip(u, partials_list[i, :] for i in 1:length(partials_list.u[1]))) end @@ -173,7 +185,7 @@ function __dual_init( alias = alias, abstol = abstol, reltol = reltol, maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions, sensealg = sensealg, u0 = new_u0, kwargs...) - return DualLinearCache(non_partial_cache, dual_type, ∂_A, ∂_b, + return DualLinearCache{dual_type}(non_partial_cache, ∂_A, ∂_b, !isnothing(∂_b) ? zero.(∂_b) : ∂_b, A, b, zeros(dual_type, length(b))) end @@ -182,11 +194,11 @@ function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...) end function SciMLBase.solve!( - cache::DualLinearCache, alg::SciMLLinearSolveAlgorithm, args...; kwargs...) + cache::DualLinearCache{DT}, alg::SciMLLinearSolveAlgorithm, args...; kwargs...) where {DT <: ForwardDiff.Dual} sol, partials = linearsolve_forwarddiff_solve( cache::DualLinearCache, cache.alg, args...; kwargs...) - dual_sol = linearsolve_dual_solution(sol.u, partials, cache.dual_type) + dual_sol = linearsolve_dual_solution(sol.u, partials, cache) if cache.dual_u isa AbstractArray cache.dual_u[:] = dual_sol From a5b56ecc8bd16cc06f221fa13922749caaf057c3 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 13 Aug 2025 12:37:14 -0400 Subject: [PATCH 3/3] clean up --- ext/LinearSolveForwardDiffExt.jl | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index 3fac6efeb..fb0ec730f 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -34,19 +34,6 @@ const DualBLinearProblem = LinearProblem{ const DualAbstractLinearProblem = Union{ DualLinearProblem, DualALinearProblem, DualBLinearProblem} -# LinearSolve.@concrete mutable struct DualLinearCache -# linear_cache -# dual_type - -# partials_A -# partials_b -# partials_u - -# dual_A -# dual_b -# dual_u -# end - LinearSolve.@concrete mutable struct DualLinearCache{DT <: Dual} linear_cache @@ -125,12 +112,6 @@ function linearsolve_dual_solution( return DT(u, partials) end -# function linearsolve_dual_solution(u::Number, partials, -# dual_type::Type{<:Dual{T, V, P}}) where {T, V, P} -# # Handle single-level duals -# return dual_type(u, partials) -# end - function linearsolve_dual_solution(u::AbstractArray, partials, cache::DualLinearCache{DT}) where {DT} # Handle single-level duals for arrays