Skip to content

Commit 90c4485

Browse files
committed
put Dual types in type system
1 parent 0a15619 commit 90c4485

File tree

1 file changed

+26
-14
lines changed

1 file changed

+26
-14
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,21 @@ const DualBLinearProblem = LinearProblem{
3434
const DualAbstractLinearProblem = Union{
3535
DualLinearProblem, DualALinearProblem, DualBLinearProblem}
3636

37-
LinearSolve.@concrete mutable struct DualLinearCache
37+
# LinearSolve.@concrete mutable struct DualLinearCache
38+
# linear_cache
39+
# dual_type
40+
41+
# partials_A
42+
# partials_b
43+
# partials_u
44+
45+
# dual_A
46+
# dual_b
47+
# dual_u
48+
# end
49+
50+
LinearSolve.@concrete mutable struct DualLinearCache{DT <: Dual}
3851
linear_cache
39-
dual_type
4052

4153
partials_A
4254
partials_b
@@ -109,21 +121,21 @@ function xp_linsolve_rhs(
109121
end
110122

111123
function linearsolve_dual_solution(
112-
u::Number, partials, dual_type)
113-
return dual_type(u, partials)
124+
u::Number, partials, cache::DualLinearCache{DT}) where {DT}
125+
return DT(u, partials)
114126
end
115127

116-
function linearsolve_dual_solution(u::Number, partials,
117-
dual_type::Type{<:Dual{T, V, P}}) where {T, V, P}
118-
# Handle single-level duals
119-
return dual_type(u, partials)
120-
end
128+
# function linearsolve_dual_solution(u::Number, partials,
129+
# dual_type::Type{<:Dual{T, V, P}}) where {T, V, P}
130+
# # Handle single-level duals
131+
# return dual_type(u, partials)
132+
# end
121133

122134
function linearsolve_dual_solution(u::AbstractArray, partials,
123-
dual_type::Type{<:Dual{T, V, P}}) where {T, V, P}
135+
cache::DualLinearCache{DT}) where {DT}
124136
# Handle single-level duals for arrays
125137
partials_list = RecursiveArrayTools.VectorOfArray(partials)
126-
return map(((uᵢ, pᵢ),) -> dual_type(uᵢ, Partials(Tuple(pᵢ))),
138+
return map(((uᵢ, pᵢ),) -> DT(uᵢ, Partials(Tuple(pᵢ))),
127139
zip(u, partials_list[i, :] for i in 1:length(partials_list.u[1])))
128140
end
129141

@@ -173,7 +185,7 @@ function __dual_init(
173185
alias = alias, abstol = abstol, reltol = reltol,
174186
maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
175187
sensealg = sensealg, u0 = new_u0, kwargs...)
176-
return DualLinearCache(non_partial_cache, dual_type, ∂_A, ∂_b,
188+
return DualLinearCache{dual_type}(non_partial_cache, ∂_A, ∂_b,
177189
!isnothing(∂_b) ? zero.(∂_b) : ∂_b, A, b, zeros(dual_type, length(b)))
178190
end
179191

@@ -182,11 +194,11 @@ function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...)
182194
end
183195

184196
function SciMLBase.solve!(
185-
cache::DualLinearCache, alg::SciMLLinearSolveAlgorithm, args...; kwargs...)
197+
cache::DualLinearCache{DT}, alg::SciMLLinearSolveAlgorithm, args...; kwargs...) where {DT <: ForwardDiff.Dual}
186198
sol,
187199
partials = linearsolve_forwarddiff_solve(
188200
cache::DualLinearCache, cache.alg, args...; kwargs...)
189-
dual_sol = linearsolve_dual_solution(sol.u, partials, cache.dual_type)
201+
dual_sol = linearsolve_dual_solution(sol.u, partials, cache)
190202

191203
if cache.dual_u isa AbstractArray
192204
cache.dual_u[:] = dual_sol

0 commit comments

Comments
 (0)