@@ -10,7 +10,7 @@ using RecursiveArrayTools
1010
1111
1212# Define type for non-nested dual numbers
13- const SingleDual{T, V, P} = Dual{T, V, P} where {T, V <: Number , P}
13+ const SingleDual{T, V, P} = Dual{T, V, P} where {T, V <: AbstractFloat , P}
1414
1515# Define type for nested dual numbers
1616const NestedDual{T, V, P} = Dual{T, V, P} where {T, V <: Dual , P}
@@ -46,7 +46,7 @@ const DualBLinearProblem = LinearProblem{
4646} where {iip}
4747
4848const DualAbstractLinearProblem = Union{
49- SingleDualLinearProblem, DualALinearProblem, DualBLinearProblem, NestedDualLinearProblem}
49+ SingleDualLinearProblem, DualALinearProblem, DualBLinearProblem} # , NestedDualLinearProblem}
5050
5151LinearSolve. @concrete mutable struct DualLinearCache
5252 linear_cache
@@ -132,7 +132,7 @@ function linearsolve_dual_solution(u::AbstractArray, partials,
132132 # Handle single-level duals for arrays
133133 partials_list = RecursiveArrayTools. VectorOfArray (partials)
134134 return map (((uᵢ, pᵢ),) -> dual_type (uᵢ, Partials (Tuple (pᵢ))),
135- zip (u, partials_list[i, :] for i in 1 : length (partials_list[1 ])))
135+ zip (u, partials_list. u [i, :] for i in 1 : length (partials_list. u [1 ])))
136136end
137137
138138
0 commit comments