Skip to content

Commit 8bae0f7

Browse files
committed
proper RAT indexing
1 parent 6ac7bf1 commit 8bae0f7

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ using RecursiveArrayTools
1010

1111

1212
# Define type for non-nested dual numbers
13-
const SingleDual{T, V, P} = Dual{T, V, P} where {T, V <:Number , P}
13+
const SingleDual{T, V, P} = Dual{T, V, P} where {T, V <:AbstractFloat , P}
1414

1515
# Define type for nested dual numbers
1616
const NestedDual{T, V, P} = Dual{T, V, P} where {T, V <:Dual, P}
@@ -46,7 +46,7 @@ const DualBLinearProblem = LinearProblem{
4646
} where {iip}
4747

4848
const DualAbstractLinearProblem = Union{
49-
SingleDualLinearProblem, DualALinearProblem, DualBLinearProblem, NestedDualLinearProblem}
49+
SingleDualLinearProblem, DualALinearProblem, DualBLinearProblem}#, NestedDualLinearProblem}
5050

5151
LinearSolve.@concrete mutable struct DualLinearCache
5252
linear_cache
@@ -132,7 +132,7 @@ function linearsolve_dual_solution(u::AbstractArray, partials,
132132
# Handle single-level duals for arrays
133133
partials_list = RecursiveArrayTools.VectorOfArray(partials)
134134
return map(((uᵢ, pᵢ),) -> dual_type(uᵢ, Partials(Tuple(pᵢ))),
135-
zip(u, partials_list[i, :] for i in 1:length(partials_list[1])))
135+
zip(u, partials_list.u[i, :] for i in 1:length(partials_list.u[1])))
136136
end
137137

138138

0 commit comments

Comments
 (0)