@@ -8,45 +8,31 @@ using ForwardDiff: Dual, Partials
8
8
using SciMLBase
9
9
using RecursiveArrayTools
10
10
11
-
12
- # Define type for non-nested dual numbers
13
- const SingleDual{T, V, P} = Dual{T, V, P} where {T, V <: Number , P}
14
-
15
- # Define type for nested dual numbers
16
- const NestedDual{T, V, P} = Dual{T, V, P} where {T, V <: Dual , P}
17
-
18
- const SingleDualLinearProblem = LinearProblem{
11
+ const DualLinearProblem = LinearProblem{
19
12
<: Union{Number, <:AbstractArray, Nothing} , iip,
20
- <: Union{<:SingleDual, <:AbstractArray{<:SingleDual }} ,
21
- <: Union{<:SingleDual, <:AbstractArray{<:SingleDual }} ,
13
+ <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P} }} ,
14
+ <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P} }} ,
22
15
<: Any
23
- } where {iip}
24
-
25
- const NestedDualLinearProblem = LinearProblem{
26
- <: Union{Number, <:AbstractArray, Nothing} , iip,
27
- <: Union{<:NestedDual, <:AbstractArray{<:NestedDual}} ,
28
- <: Union{<:NestedDual, <:AbstractArray{<:NestedDual}} ,
29
- <: Any
30
- } where {iip}
16
+ } where {iip, T, V, P}
31
17
32
18
const DualALinearProblem = LinearProblem{
33
19
<: Union{Number, <:AbstractArray, Nothing} ,
34
20
iip,
35
- <: Union{<:SingleDual, <:AbstractArray{<:SingleDual }} ,
21
+ <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P} }} ,
36
22
<: Union{Number, <:AbstractArray} ,
37
23
<: Any
38
- } where {iip}
24
+ } where {iip, T, V, P }
39
25
40
26
const DualBLinearProblem = LinearProblem{
41
27
<: Union{Number, <:AbstractArray, Nothing} ,
42
28
iip,
43
29
<: Union{Number, <:AbstractArray} ,
44
- <: Union{<:SingleDual, <:AbstractArray{<:SingleDual }} ,
30
+ <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P} }} ,
45
31
<: Any
46
- } where {iip}
32
+ } where {iip, T, V, P }
47
33
48
34
const DualAbstractLinearProblem = Union{
49
- SingleDualLinearProblem , DualALinearProblem, DualBLinearProblem} # , NestedDualLinearProblem }
35
+ DualLinearProblem , DualALinearProblem, DualBLinearProblem}
50
36
51
37
LinearSolve. @concrete mutable struct DualLinearCache
52
38
linear_cache
63
49
64
50
function linearsolve_forwarddiff_solve (cache:: DualLinearCache , alg, args... ; kwargs... )
65
51
# Solve the primal problem
52
+ @info " here"
66
53
dual_u0 = copy (cache. linear_cache. u)
67
54
sol = solve! (cache. linear_cache, alg, args... ; kwargs... )
68
55
primal_b = copy (cache. linear_cache. b)
@@ -122,47 +109,19 @@ function linearsolve_dual_solution(
122
109
end
123
110
124
111
function linearsolve_dual_solution (u:: Number , partials,
125
- dual_type:: Type{<:Dual{T, V, P}} ) where {T, V <: AbstractFloat , P}
112
+ dual_type:: Type{<:Dual{T, V, P}} ) where {T, V, P}
126
113
# Handle single-level duals
127
114
return dual_type (u, partials)
128
115
end
129
116
130
117
function linearsolve_dual_solution (u:: AbstractArray , partials,
131
- dual_type:: Type{<:Dual{T, V, P}} ) where {T, V <: AbstractFloat , P}
118
+ dual_type:: Type{<:Dual{T, V, P}} ) where {T, V, P}
132
119
# Handle single-level duals for arrays
133
120
partials_list = RecursiveArrayTools. VectorOfArray (partials)
134
121
return map (((uᵢ, pᵢ),) -> dual_type (uᵢ, Partials (Tuple (pᵢ))),
135
122
zip (u, partials_list[i, :] for i in 1 : length (partials_list. u[1 ])))
136
123
end
137
124
138
-
139
- function linearsolve_dual_solution (
140
- u:: Number , partials, dual_type:: Type{<:Dual{T, V, P}} ) where {T, V <: Dual , P}
141
- # Handle nested duals - recursive case
142
- # For nested duals, u itself could be a dual number with its own partials
143
- inner_dual_type = V
144
- outer_tag_type = T
145
-
146
- # Reconstruct the nested dual by first building the inner dual, then the outer one
147
- inner_dual = u # u is already a dual for the inner level
148
-
149
- # Create outer dual with the inner dual as its value
150
- return Dual {outer_tag_type, typeof(inner_dual), P} (inner_dual, partials)
151
- end
152
-
153
- function linearsolve_dual_solution (u:: AbstractArray , partials,
154
- dual_type:: Type{<:Dual{T, V, P}} ) where {T, V <: Dual , P}
155
- # Handle nested duals for arrays - recursive case
156
- inner_dual_type = V
157
- outer_tag_type = T
158
-
159
- partials_list = RecursiveArrayTools. VectorOfArray (partials)
160
-
161
- # For nested duals, each element of u could be a dual number with its own partials
162
- return map (((uᵢ, pᵢ),) -> Dual {outer_tag_type, typeof(uᵢ), P} (uᵢ, Partials (Tuple (pᵢ))),
163
- zip (u, partials_list[i, :] for i in 1 : length (partials_list[1 ])))
164
- end
165
-
166
125
function SciMLBase. init (
167
126
prob:: DualAbstractLinearProblem , alg:: LinearSolve.SciMLLinearSolveAlgorithm ,
168
127
args... ;
0 commit comments