@@ -10,7 +10,7 @@ using RecursiveArrayTools
10
10
11
11
12
12
# Define type for non-nested dual numbers
13
- const SingleDual{T, V, P} = Dual{T, V, P} where {T, V <: Number , P}
13
+ const SingleDual{T, V, P} = Dual{T, V, P} where {T, V <: AbstractFloat , P}
14
14
15
15
# Define type for nested dual numbers
16
16
const NestedDual{T, V, P} = Dual{T, V, P} where {T, V <: Dual , P}
@@ -46,7 +46,7 @@ const DualBLinearProblem = LinearProblem{
46
46
} where {iip}
47
47
48
48
const DualAbstractLinearProblem = Union{
49
- SingleDualLinearProblem, DualALinearProblem, DualBLinearProblem, NestedDualLinearProblem}
49
+ SingleDualLinearProblem, DualALinearProblem, DualBLinearProblem} # , NestedDualLinearProblem}
50
50
51
51
LinearSolve. @concrete mutable struct DualLinearCache
52
52
linear_cache
@@ -132,7 +132,7 @@ function linearsolve_dual_solution(u::AbstractArray, partials,
132
132
# Handle single-level duals for arrays
133
133
partials_list = RecursiveArrayTools. VectorOfArray (partials)
134
134
return map (((uᵢ, pᵢ),) -> dual_type (uᵢ, Partials (Tuple (pᵢ))),
135
- zip (u, partials_list[i, :] for i in 1 : length (partials_list[1 ])))
135
+ zip (u, partials_list. u [i, :] for i in 1 : length (partials_list. u [1 ])))
136
136
end
137
137
138
138
0 commit comments