@@ -8,45 +8,31 @@ using ForwardDiff: Dual, Partials
88using SciMLBase
99using RecursiveArrayTools
1010
11-
12- # Define type for non-nested dual numbers
13- const SingleDual{T, V, P} = Dual{T, V, P} where {T, V <: Number , P}
14-
15- # Define type for nested dual numbers
16- const NestedDual{T, V, P} = Dual{T, V, P} where {T, V <: Dual , P}
17-
18- const SingleDualLinearProblem = LinearProblem{
11+ const DualLinearProblem = LinearProblem{
1912 <: Union{Number, <:AbstractArray, Nothing} , iip,
20- <: Union{<:SingleDual, <:AbstractArray{<:SingleDual }} ,
21- <: Union{<:SingleDual, <:AbstractArray{<:SingleDual }} ,
13+ <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P} }} ,
14+ <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P} }} ,
2215 <: Any
23- } where {iip}
24-
25- const NestedDualLinearProblem = LinearProblem{
26- <: Union{Number, <:AbstractArray, Nothing} , iip,
27- <: Union{<:NestedDual, <:AbstractArray{<:NestedDual}} ,
28- <: Union{<:NestedDual, <:AbstractArray{<:NestedDual}} ,
29- <: Any
30- } where {iip}
16+ } where {iip, T, V, P}
3117
3218const DualALinearProblem = LinearProblem{
3319 <: Union{Number, <:AbstractArray, Nothing} ,
3420 iip,
35- <: Union{<:SingleDual, <:AbstractArray{<:SingleDual }} ,
21+ <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P} }} ,
3622 <: Union{Number, <:AbstractArray} ,
3723 <: Any
38- } where {iip}
24+ } where {iip, T, V, P }
3925
4026const DualBLinearProblem = LinearProblem{
4127 <: Union{Number, <:AbstractArray, Nothing} ,
4228 iip,
4329 <: Union{Number, <:AbstractArray} ,
44- <: Union{<:SingleDual, <:AbstractArray{<:SingleDual }} ,
30+ <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P} }} ,
4531 <: Any
46- } where {iip}
32+ } where {iip, T, V, P }
4733
4834const DualAbstractLinearProblem = Union{
49- SingleDualLinearProblem , DualALinearProblem, DualBLinearProblem} # , NestedDualLinearProblem }
35+ DualLinearProblem , DualALinearProblem, DualBLinearProblem}
5036
5137LinearSolve. @concrete mutable struct DualLinearCache
5238 linear_cache
6349
6450function linearsolve_forwarddiff_solve (cache:: DualLinearCache , alg, args... ; kwargs... )
6551 # Solve the primal problem
52+ @info " here"
6653 dual_u0 = copy (cache. linear_cache. u)
6754 sol = solve! (cache. linear_cache, alg, args... ; kwargs... )
6855 primal_b = copy (cache. linear_cache. b)
@@ -122,47 +109,19 @@ function linearsolve_dual_solution(
122109end
123110
124111function linearsolve_dual_solution (u:: Number , partials,
125- dual_type:: Type{<:Dual{T, V, P}} ) where {T, V <: AbstractFloat , P}
112+ dual_type:: Type{<:Dual{T, V, P}} ) where {T, V, P}
126113 # Handle single-level duals
127114 return dual_type (u, partials)
128115end
129116
130117function linearsolve_dual_solution (u:: AbstractArray , partials,
131- dual_type:: Type{<:Dual{T, V, P}} ) where {T, V <: AbstractFloat , P}
118+ dual_type:: Type{<:Dual{T, V, P}} ) where {T, V, P}
132119 # Handle single-level duals for arrays
133120 partials_list = RecursiveArrayTools. VectorOfArray (partials)
134121 return map (((uᵢ, pᵢ),) -> dual_type (uᵢ, Partials (Tuple (pᵢ))),
135122 zip (u, partials_list[i, :] for i in 1 : length (partials_list. u[1 ])))
136123end
137124
138-
139- function linearsolve_dual_solution (
140- u:: Number , partials, dual_type:: Type{<:Dual{T, V, P}} ) where {T, V <: Dual , P}
141- # Handle nested duals - recursive case
142- # For nested duals, u itself could be a dual number with its own partials
143- inner_dual_type = V
144- outer_tag_type = T
145-
146- # Reconstruct the nested dual by first building the inner dual, then the outer one
147- inner_dual = u # u is already a dual for the inner level
148-
149- # Create outer dual with the inner dual as its value
150- return Dual {outer_tag_type, typeof(inner_dual), P} (inner_dual, partials)
151- end
152-
153- function linearsolve_dual_solution (u:: AbstractArray , partials,
154- dual_type:: Type{<:Dual{T, V, P}} ) where {T, V <: Dual , P}
155- # Handle nested duals for arrays - recursive case
156- inner_dual_type = V
157- outer_tag_type = T
158-
159- partials_list = RecursiveArrayTools. VectorOfArray (partials)
160-
161- # For nested duals, each element of u could be a dual number with its own partials
162- return map (((uᵢ, pᵢ),) -> Dual {outer_tag_type, typeof(uᵢ), P} (uᵢ, Partials (Tuple (pᵢ))),
163- zip (u, partials_list[i, :] for i in 1 : length (partials_list[1 ])))
164- end
165-
166125function SciMLBase. init (
167126 prob:: DualAbstractLinearProblem , alg:: LinearSolve.SciMLLinearSolveAlgorithm ,
168127 args... ;
0 commit comments