Skip to content

Commit b307bb0

Browse files
authored
Get rid of deepcopy, put Dual types in type system (#724)
* use recursivecopy instead of deepcopy * put Dual types in type system * clean up
1 parent f1d92b1 commit b307bb0

File tree

1 file changed

+15
-16
lines changed

1 file changed

+15
-16
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,8 @@ const DualBLinearProblem = LinearProblem{
3434
const DualAbstractLinearProblem = Union{
3535
DualLinearProblem, DualALinearProblem, DualBLinearProblem}
3636

37-
LinearSolve.@concrete mutable struct DualLinearCache
37+
LinearSolve.@concrete mutable struct DualLinearCache{DT <: Dual}
3838
linear_cache
39-
dual_type
4039

4140
partials_A
4241
partials_b
@@ -54,7 +53,13 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
5453
primal_b = copy(cache.linear_cache.b)
5554
uu = sol.u
5655

57-
primal_sol = deepcopy(sol)
56+
primal_sol = (;
57+
u = recursivecopy(sol.u),
58+
resid = recursivecopy(sol.resid),
59+
retcode = recursivecopy(sol.retcode),
60+
iters = recursivecopy(sol.iters),
61+
stats = recursivecopy(sol.stats)
62+
)
5863

5964
# Solves Dual partials separately
6065
∂_A = cache.partials_A
@@ -103,21 +108,15 @@ function xp_linsolve_rhs(
103108
end
104109

105110
function linearsolve_dual_solution(
106-
u::Number, partials, dual_type)
107-
return dual_type(u, partials)
108-
end
109-
110-
function linearsolve_dual_solution(u::Number, partials,
111-
dual_type::Type{<:Dual{T, V, P}}) where {T, V, P}
112-
# Handle single-level duals
113-
return dual_type(u, partials)
111+
u::Number, partials, cache::DualLinearCache{DT}) where {DT}
112+
return DT(u, partials)
114113
end
115114

116115
function linearsolve_dual_solution(u::AbstractArray, partials,
117-
dual_type::Type{<:Dual{T, V, P}}) where {T, V, P}
116+
cache::DualLinearCache{DT}) where {DT}
118117
# Handle single-level duals for arrays
119118
partials_list = RecursiveArrayTools.VectorOfArray(partials)
120-
return map(((uᵢ, pᵢ),) -> dual_type(uᵢ, Partials(Tuple(pᵢ))),
119+
return map(((uᵢ, pᵢ),) -> DT(uᵢ, Partials(Tuple(pᵢ))),
121120
zip(u, partials_list[i, :] for i in 1:length(partials_list.u[1])))
122121
end
123122

@@ -167,7 +166,7 @@ function __dual_init(
167166
alias = alias, abstol = abstol, reltol = reltol,
168167
maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
169168
sensealg = sensealg, u0 = new_u0, kwargs...)
170-
return DualLinearCache(non_partial_cache, dual_type, ∂_A, ∂_b,
169+
return DualLinearCache{dual_type}(non_partial_cache, ∂_A, ∂_b,
171170
!isnothing(∂_b) ? zero.(∂_b) : ∂_b, A, b, zeros(dual_type, length(b)))
172171
end
173172

@@ -176,11 +175,11 @@ function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...)
176175
end
177176

178177
function SciMLBase.solve!(
179-
cache::DualLinearCache, alg::SciMLLinearSolveAlgorithm, args...; kwargs...)
178+
cache::DualLinearCache{DT}, alg::SciMLLinearSolveAlgorithm, args...; kwargs...) where {DT <: ForwardDiff.Dual}
180179
sol,
181180
partials = linearsolve_forwarddiff_solve(
182181
cache::DualLinearCache, cache.alg, args...; kwargs...)
183-
dual_sol = linearsolve_dual_solution(sol.u, partials, cache.dual_type)
182+
dual_sol = linearsolve_dual_solution(sol.u, partials, cache)
184183

185184
if cache.dual_u isa AbstractArray
186185
cache.dual_u[:] = dual_sol

0 commit comments

Comments
 (0)