Skip to content

Commit d9ed726

Browse files
committed
get rid of unecessary things
1 parent 308bc76 commit d9ed726

File tree

1 file changed

+12
-53
lines changed

1 file changed

+12
-53
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 12 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -8,45 +8,31 @@ using ForwardDiff: Dual, Partials
88
using SciMLBase
99
using RecursiveArrayTools
1010

11-
12-
# Define type for non-nested dual numbers
13-
const SingleDual{T, V, P} = Dual{T, V, P} where {T, V <:Number , P}
14-
15-
# Define type for nested dual numbers
16-
const NestedDual{T, V, P} = Dual{T, V, P} where {T, V <:Dual, P}
17-
18-
const SingleDualLinearProblem = LinearProblem{
11+
const DualLinearProblem = LinearProblem{
1912
<:Union{Number, <:AbstractArray, Nothing}, iip,
20-
<:Union{<:SingleDual, <:AbstractArray{<:SingleDual}},
21-
<:Union{<:SingleDual, <:AbstractArray{<:SingleDual}},
13+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}},
14+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}},
2215
<:Any
23-
} where {iip}
24-
25-
const NestedDualLinearProblem = LinearProblem{
26-
<:Union{Number, <:AbstractArray, Nothing}, iip,
27-
<:Union{<:NestedDual, <:AbstractArray{<:NestedDual}},
28-
<:Union{<:NestedDual, <:AbstractArray{<:NestedDual}},
29-
<:Any
30-
} where {iip}
16+
} where {iip, T, V, P}
3117

3218
const DualALinearProblem = LinearProblem{
3319
<:Union{Number, <:AbstractArray, Nothing},
3420
iip,
35-
<:Union{<:SingleDual, <:AbstractArray{<:SingleDual}},
21+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}},
3622
<:Union{Number, <:AbstractArray},
3723
<:Any
38-
} where {iip}
24+
} where {iip, T, V, P}
3925

4026
const DualBLinearProblem = LinearProblem{
4127
<:Union{Number, <:AbstractArray, Nothing},
4228
iip,
4329
<:Union{Number, <:AbstractArray},
44-
<:Union{<:SingleDual, <:AbstractArray{<:SingleDual}},
30+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}},
4531
<:Any
46-
} where {iip}
32+
} where {iip, T, V, P}
4733

4834
const DualAbstractLinearProblem = Union{
49-
SingleDualLinearProblem, DualALinearProblem, DualBLinearProblem}#, NestedDualLinearProblem}
35+
DualLinearProblem, DualALinearProblem, DualBLinearProblem}
5036

5137
LinearSolve.@concrete mutable struct DualLinearCache
5238
linear_cache
@@ -63,6 +49,7 @@ end
6349

6450
function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...)
6551
# Solve the primal problem
52+
@info "here"
6653
dual_u0 = copy(cache.linear_cache.u)
6754
sol = solve!(cache.linear_cache, alg, args...; kwargs...)
6855
primal_b = copy(cache.linear_cache.b)
@@ -122,47 +109,19 @@ function linearsolve_dual_solution(
122109
end
123110

124111
function linearsolve_dual_solution(u::Number, partials,
125-
dual_type::Type{<:Dual{T, V, P}}) where {T, V <: AbstractFloat, P}
112+
dual_type::Type{<:Dual{T, V, P}}) where {T, V, P}
126113
# Handle single-level duals
127114
return dual_type(u, partials)
128115
end
129116

130117
function linearsolve_dual_solution(u::AbstractArray, partials,
131-
dual_type::Type{<:Dual{T, V, P}}) where {T, V <: AbstractFloat, P}
118+
dual_type::Type{<:Dual{T, V, P}}) where {T, V, P}
132119
# Handle single-level duals for arrays
133120
partials_list = RecursiveArrayTools.VectorOfArray(partials)
134121
return map(((uᵢ, pᵢ),) -> dual_type(uᵢ, Partials(Tuple(pᵢ))),
135122
zip(u, partials_list[i, :] for i in 1:length(partials_list.u[1])))
136123
end
137124

138-
139-
function linearsolve_dual_solution(
140-
u::Number, partials, dual_type::Type{<:Dual{T, V, P}}) where {T, V <: Dual, P}
141-
# Handle nested duals - recursive case
142-
# For nested duals, u itself could be a dual number with its own partials
143-
inner_dual_type = V
144-
outer_tag_type = T
145-
146-
# Reconstruct the nested dual by first building the inner dual, then the outer one
147-
inner_dual = u # u is already a dual for the inner level
148-
149-
# Create outer dual with the inner dual as its value
150-
return Dual{outer_tag_type, typeof(inner_dual), P}(inner_dual, partials)
151-
end
152-
153-
function linearsolve_dual_solution(u::AbstractArray, partials,
154-
dual_type::Type{<:Dual{T, V, P}}) where {T, V <: Dual, P}
155-
# Handle nested duals for arrays - recursive case
156-
inner_dual_type = V
157-
outer_tag_type = T
158-
159-
partials_list = RecursiveArrayTools.VectorOfArray(partials)
160-
161-
# For nested duals, each element of u could be a dual number with its own partials
162-
return map(((uᵢ, pᵢ),) -> Dual{outer_tag_type, typeof(uᵢ), P}(uᵢ, Partials(Tuple(pᵢ))),
163-
zip(u, partials_list[i, :] for i in 1:length(partials_list[1])))
164-
end
165-
166125
function SciMLBase.init(
167126
prob::DualAbstractLinearProblem, alg::LinearSolve.SciMLLinearSolveAlgorithm,
168127
args...;

0 commit comments

Comments
 (0)