@@ -55,21 +55,16 @@ LinearSolve.@concrete mutable struct DualLinearCache{DT}
5555 dual_u
5656end
5757
58- function linearsolve_forwarddiff_solve (cache:: DualLinearCache , alg, args... ; kwargs... )
58+ function linearsolve_forwarddiff_solve! (cache:: DualLinearCache , alg, args... ; kwargs... )
5959 # Solve the primal problem
6060 cache. dual_u0_cache .= cache. linear_cache. u
6161 sol = solve! (cache. linear_cache, alg, args... ; kwargs... )
6262
6363 cache. primal_b_cache .= cache. linear_cache. b
6464 uu = sol. u
6565
66- primal_sol = (;
67- u = recursivecopy (sol. u),
68- resid = recursivecopy (sol. resid),
69- retcode = recursivecopy (sol. retcode),
70- iters = recursivecopy (sol. iters),
71- stats = recursivecopy (sol. stats)
72- )
66+ # Store solution metadata without copying - we'll return this
67+ primal_sol = sol
7368
7469 # Solves Dual partials separately
7570 ∂_A = cache. partials_A
@@ -89,9 +84,7 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
8984 # Reset to the original `b` and `u`, users will expect that `b` doesn't change if they don't tell it to
9085 cache. linear_cache. b .= cache. primal_b_cache
9186
92- partial_sols = rhs_list
93-
94- primal_sol, partial_sols
87+ return primal_sol
9588end
9689
9790function xp_linsolve_rhs! (uu, ∂_A:: Union{<:Partials, <:AbstractArray{<:Partials}} ,
153146function linearsolve_dual_solution (u:: AbstractArray , partials,
154147 cache:: DualLinearCache{DT} ) where {T, V, N, DT <: Dual{T,V,N} }
155148 # Optimized in-place version that reuses cache.dual_u
156- linearsolve_dual_solution! (cache. dual_u, u, partials)
157- return cache. dual_u
149+ linearsolve_dual_solution! (getfield ( cache, : dual_u) , u, partials)
150+ return getfield ( cache, : dual_u)
158151end
159152
160153function linearsolve_dual_solution! (dual_u:: AbstractArray{DT} , u:: AbstractArray , partials) where {T, V, N, DT <: Dual{T,V,N} }
@@ -254,23 +247,22 @@ function __dual_init(
254247end
255248
256249function SciMLBase. solve! (cache:: DualLinearCache , args... ; kwargs... )
257- solve! (cache, cache. alg, args... ; kwargs... )
250+ solve! (cache, getfield ( cache, :linear_cache ) . alg, args... ; kwargs... )
258251end
259252
260253function SciMLBase. solve! (
261254 cache:: DualLinearCache{DT} , alg:: SciMLLinearSolveAlgorithm , args... ; kwargs... ) where {DT <: ForwardDiff.Dual }
262- sol,
263- partials = linearsolve_forwarddiff_solve (
264- cache:: DualLinearCache , cache. alg, args... ; kwargs... )
265- dual_sol = linearsolve_dual_solution (sol. u, partials, cache)
255+ primal_sol = linearsolve_forwarddiff_solve! (
256+ cache:: DualLinearCache , getfield (cache, :linear_cache ). alg, args... ; kwargs... )
257+ dual_sol = linearsolve_dual_solution (getfield (cache,:linear_cache ). u, getfield (cache, :rhs_list ), cache)
266258
267259 # For scalars, we still need to assign since cache.dual_u might not be pre-allocated
268- if ! (cache. dual_u isa AbstractArray)
269- cache. dual_u = dual_sol
260+ if ! (getfield ( cache, : dual_u) isa AbstractArray)
261+ setfield! ( cache, : dual_u, dual_sol)
270262 end
271263
272264 return SciMLBase. build_linear_solution (
273- cache. alg, cache. dual_u, sol . resid, cache; sol . retcode, sol . iters, sol . stats
265+ getfield ( cache, :linear_cache ) . alg, getfield ( cache, : dual_u), primal_sol . resid, cache; primal_sol . retcode, primal_sol . iters, primal_sol . stats
274266 )
275267end
276268
0 commit comments