@@ -34,9 +34,21 @@ const DualBLinearProblem = LinearProblem{
34
34
const DualAbstractLinearProblem = Union{
35
35
DualLinearProblem, DualALinearProblem, DualBLinearProblem}
36
36
37
- LinearSolve. @concrete mutable struct DualLinearCache
37
+ # LinearSolve.@concrete mutable struct DualLinearCache
38
+ # linear_cache
39
+ # dual_type
40
+
41
+ # partials_A
42
+ # partials_b
43
+ # partials_u
44
+
45
+ # dual_A
46
+ # dual_b
47
+ # dual_u
48
+ # end
49
+
50
+ LinearSolve. @concrete mutable struct DualLinearCache{DT <: Dual }
38
51
linear_cache
39
- dual_type
40
52
41
53
partials_A
42
54
partials_b
@@ -109,21 +121,21 @@ function xp_linsolve_rhs(
109
121
end
110
122
111
123
function linearsolve_dual_solution (
112
- u:: Number , partials, dual_type)
113
- return dual_type (u, partials)
124
+ u:: Number , partials, cache :: DualLinearCache{DT} ) where {DT}
125
+ return DT (u, partials)
114
126
end
115
127
116
- function linearsolve_dual_solution (u:: Number , partials,
117
- dual_type:: Type{<:Dual{T, V, P}} ) where {T, V, P}
118
- # Handle single-level duals
119
- return dual_type (u, partials)
120
- end
128
+ # function linearsolve_dual_solution(u::Number, partials,
129
+ # dual_type::Type{<:Dual{T, V, P}}) where {T, V, P}
130
+ # # Handle single-level duals
131
+ # return dual_type(u, partials)
132
+ # end
121
133
122
134
function linearsolve_dual_solution (u:: AbstractArray , partials,
123
- dual_type :: Type{<:Dual{T, V, P}} ) where {T, V, P }
135
+ cache :: DualLinearCache{DT} ) where {DT }
124
136
# Handle single-level duals for arrays
125
137
partials_list = RecursiveArrayTools. VectorOfArray (partials)
126
- return map (((uᵢ, pᵢ),) -> dual_type (uᵢ, Partials (Tuple (pᵢ))),
138
+ return map (((uᵢ, pᵢ),) -> DT (uᵢ, Partials (Tuple (pᵢ))),
127
139
zip (u, partials_list[i, :] for i in 1 : length (partials_list. u[1 ])))
128
140
end
129
141
@@ -173,7 +185,7 @@ function __dual_init(
173
185
alias = alias, abstol = abstol, reltol = reltol,
174
186
maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
175
187
sensealg = sensealg, u0 = new_u0, kwargs... )
176
- return DualLinearCache (non_partial_cache, dual_type , ∂_A, ∂_b,
188
+ return DualLinearCache {dual_type} (non_partial_cache, ∂_A, ∂_b,
177
189
! isnothing (∂_b) ? zero .(∂_b) : ∂_b, A, b, zeros (dual_type, length (b)))
178
190
end
179
191
@@ -182,11 +194,11 @@ function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...)
182
194
end
183
195
184
196
function SciMLBase. solve! (
185
- cache:: DualLinearCache , alg:: SciMLLinearSolveAlgorithm , args... ; kwargs... )
197
+ cache:: DualLinearCache{DT} , alg:: SciMLLinearSolveAlgorithm , args... ; kwargs... ) where {DT <: ForwardDiff.Dual }
186
198
sol,
187
199
partials = linearsolve_forwarddiff_solve (
188
200
cache:: DualLinearCache , cache. alg, args... ; kwargs... )
189
- dual_sol = linearsolve_dual_solution (sol. u, partials, cache. dual_type )
201
+ dual_sol = linearsolve_dual_solution (sol. u, partials, cache)
190
202
191
203
if cache. dual_u isa AbstractArray
192
204
cache. dual_u[:] = dual_sol
0 commit comments