@@ -34,9 +34,8 @@ const DualBLinearProblem = LinearProblem{
34
34
const DualAbstractLinearProblem = Union{
35
35
DualLinearProblem, DualALinearProblem, DualBLinearProblem}
36
36
37
- LinearSolve. @concrete mutable struct DualLinearCache
37
+ LinearSolve. @concrete mutable struct DualLinearCache{DT <: Dual }
38
38
linear_cache
39
- dual_type
40
39
41
40
partials_A
42
41
partials_b
@@ -54,7 +53,13 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
54
53
primal_b = copy (cache. linear_cache. b)
55
54
uu = sol. u
56
55
57
- primal_sol = deepcopy (sol)
56
+ primal_sol = (;
57
+ u = recursivecopy (sol. u),
58
+ resid = recursivecopy (sol. resid),
59
+ retcode = recursivecopy (sol. retcode),
60
+ iters = recursivecopy (sol. iters),
61
+ stats = recursivecopy (sol. stats)
62
+ )
58
63
59
64
# Solves Dual partials separately
60
65
∂_A = cache. partials_A
@@ -103,21 +108,15 @@ function xp_linsolve_rhs(
103
108
end
104
109
105
110
function linearsolve_dual_solution (
106
- u:: Number , partials, dual_type)
107
- return dual_type (u, partials)
108
- end
109
-
110
- function linearsolve_dual_solution (u:: Number , partials,
111
- dual_type:: Type{<:Dual{T, V, P}} ) where {T, V, P}
112
- # Handle single-level duals
113
- return dual_type (u, partials)
111
+ u:: Number , partials, cache:: DualLinearCache{DT} ) where {DT}
112
+ return DT (u, partials)
114
113
end
115
114
116
115
function linearsolve_dual_solution (u:: AbstractArray , partials,
117
- dual_type :: Type{<:Dual{T, V, P}} ) where {T, V, P }
116
+ cache :: DualLinearCache{DT} ) where {DT }
118
117
# Handle single-level duals for arrays
119
118
partials_list = RecursiveArrayTools. VectorOfArray (partials)
120
- return map (((uᵢ, pᵢ),) -> dual_type (uᵢ, Partials (Tuple (pᵢ))),
119
+ return map (((uᵢ, pᵢ),) -> DT (uᵢ, Partials (Tuple (pᵢ))),
121
120
zip (u, partials_list[i, :] for i in 1 : length (partials_list. u[1 ])))
122
121
end
123
122
@@ -167,7 +166,7 @@ function __dual_init(
167
166
alias = alias, abstol = abstol, reltol = reltol,
168
167
maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
169
168
sensealg = sensealg, u0 = new_u0, kwargs... )
170
- return DualLinearCache (non_partial_cache, dual_type , ∂_A, ∂_b,
169
+ return DualLinearCache {dual_type} (non_partial_cache, ∂_A, ∂_b,
171
170
! isnothing (∂_b) ? zero .(∂_b) : ∂_b, A, b, zeros (dual_type, length (b)))
172
171
end
173
172
@@ -176,11 +175,11 @@ function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...)
176
175
end
177
176
178
177
function SciMLBase. solve! (
179
- cache:: DualLinearCache , alg:: SciMLLinearSolveAlgorithm , args... ; kwargs... )
178
+ cache:: DualLinearCache{DT} , alg:: SciMLLinearSolveAlgorithm , args... ; kwargs... ) where {DT <: ForwardDiff.Dual }
180
179
sol,
181
180
partials = linearsolve_forwarddiff_solve (
182
181
cache:: DualLinearCache , cache. alg, args... ; kwargs... )
183
- dual_sol = linearsolve_dual_solution (sol. u, partials, cache. dual_type )
182
+ dual_sol = linearsolve_dual_solution (sol. u, partials, cache)
184
183
185
184
if cache. dual_u isa AbstractArray
186
185
cache. dual_u[:] = dual_sol
0 commit comments