Skip to content

Commit c9a6cb0

Browse files
committed
add utilities to SciMLBase
1 parent aa5aa77 commit c9a6cb0

File tree

3 files changed

+74
-341
lines changed

3 files changed

+74
-341
lines changed

ext/SciMLBaseForwardDiffExt.jl

Lines changed: 17 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,16 @@ module SciMLBaseForwardDiffExt
33
using SciMLBase, ForwardDiff
44
using ArrayInterface
55

6-
import SciMLBase: hasdualpromote, wrapfun_oop, wrapfun_iip, prob2dtmin, isdualtype, value
6+
import SciMLBase:
7+
wrapfun_oop, wrapfun_iip, isdualtype, value, DualEltypeChecker,
8+
AbstractTimeseriesSolution, NonlinearProblem, NonlinearLeastSquaresProblem,
9+
ODEProblem, SDEProblem, RODEProblem, DDEProblem, PDEProblem, DAEProblem,
10+
RecursiveArrayTools, totallength
11+
712

813

914
eltypedual(x) = eltype(x) <: ForwardDiff.Dual
1015
isdualtype(::Type{<:ForwardDiff.Dual}) = true
11-
const dualT = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag, Float64}, Float64, 1}
12-
dualgen(::Type{T}) where {T} = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag, T}, T, 1}
1316

1417
# Copy of the other prob2dtmin dispatch, just for optionality
1518
function prob2dtmin(tspan, ::ForwardDiff.Dual, use_end_time)
@@ -22,91 +25,16 @@ function prob2dtmin(tspan, ::ForwardDiff.Dual, use_end_time)
2225
end
2326
end
2427

25-
function hasdualpromote(u0, t::Number)
26-
hasmethod(ArrayInterface.promote_eltype,
27-
Tuple{Type{typeof(u0)}, Type{dualgen(eltype(u0))}}) &&
28-
hasmethod(promote_rule,
29-
Tuple{Type{eltype(u0)}, Type{dualgen(eltype(u0))}}) &&
30-
hasmethod(promote_rule,
31-
Tuple{Type{eltype(u0)}, Type{typeof(t)}})
32-
end
33-
34-
const NORECOMPILE_IIP_SUPPORTED_ARGS = (
35-
Tuple{Vector{Float64}, Vector{Float64},
36-
Vector{Float64}, Float64},
37-
Tuple{Vector{Float64}, Vector{Float64},
38-
SciMLBase.NullParameters, Float64})
39-
40-
const oop_arglists = (Tuple{Vector{Float64}, Vector{Float64}, Float64},
41-
Tuple{Vector{Float64}, SciMLBase.NullParameters, Float64},
42-
Tuple{Vector{Float64}, Vector{Float64}, dualT},
43-
Tuple{Vector{dualT}, Vector{Float64}, Float64},
44-
Tuple{Vector{dualT}, SciMLBase.NullParameters, Float64},
45-
Tuple{Vector{Float64}, SciMLBase.NullParameters, dualT})
46-
47-
const NORECOMPILE_OOP_SUPPORTED_ARGS = (Tuple{Vector{Float64},
48-
Vector{Float64}, Float64},
49-
Tuple{Vector{Float64},
50-
SciMLBase.NullParameters, Float64})
51-
const oop_returnlists = (Vector{Float64}, Vector{Float64},
52-
ntuple(x -> Vector{dualT}, length(oop_arglists) - 2)...)
53-
54-
function wrapfun_oop(ff, inputs::Tuple = ())
55-
if !isempty(inputs)
56-
IT = Tuple{map(typeof, inputs)...}
57-
if IT NORECOMPILE_OOP_SUPPORTED_ARGS
58-
throw(NoRecompileArgumentError(IT))
59-
end
60-
end
61-
FunctionWrappersWrappers.FunctionWrappersWrapper(ff, oop_arglists,
62-
oop_returnlists)
63-
end
64-
65-
function wrapfun_iip(ff,
66-
inputs::Tuple{T1, T2, T3, T4}) where {T1, T2, T3, T4}
67-
T = eltype(T2)
68-
dualT = dualgen(T)
69-
dualT1 = ArrayInterface.promote_eltype(T1, dualT)
70-
dualT2 = ArrayInterface.promote_eltype(T2, dualT)
71-
dualT4 = dualgen(promote_type(T, T4))
28+
# function hasdualpromote(u0, t::Number)
29+
# hasmethod(ArrayInterface.promote_eltype,
30+
# Tuple{Type{typeof(u0)}, Type{dualgen(eltype(u0))}}) &&
31+
# hasmethod(promote_rule,
32+
# Tuple{Type{eltype(u0)}, Type{dualgen(eltype(u0))}}) &&
33+
# hasmethod(promote_rule,
34+
# Tuple{Type{eltype(u0)}, Type{typeof(t)}})
35+
# end
7236

73-
iip_arglists = (Tuple{T1, T2, T3, T4},
74-
Tuple{dualT1, dualT2, T3, T4},
75-
Tuple{dualT1, T2, T3, dualT4},
76-
Tuple{dualT1, dualT2, T3, dualT4})
7737

78-
iip_returnlists = ntuple(x -> Nothing, 4)
79-
80-
fwt = map(iip_arglists, iip_returnlists) do A, R
81-
FunctionWrappersWrappers.FunctionWrappers.FunctionWrapper{R, A}(Void(ff))
82-
end
83-
FunctionWrappersWrappers.FunctionWrappersWrapper{typeof(fwt), false}(fwt)
84-
end
85-
86-
const iip_arglists_default = (
87-
Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64},
88-
Float64},
89-
Tuple{Vector{Float64}, Vector{Float64},
90-
SciMLBase.NullParameters,
91-
Float64
92-
},
93-
Tuple{Vector{dualT}, Vector{Float64}, Vector{Float64}, dualT},
94-
Tuple{Vector{dualT}, Vector{dualT}, Vector{Float64}, dualT},
95-
Tuple{Vector{dualT}, Vector{dualT}, Vector{Float64}, Float64},
96-
Tuple{Vector{dualT}, Vector{dualT}, SciMLBase.NullParameters,
97-
Float64
98-
},
99-
Tuple{Vector{dualT}, Vector{Float64},
100-
SciMLBase.NullParameters, dualT
101-
})
102-
const iip_returnlists_default = ntuple(x -> Nothing, length(iip_arglists_default))
103-
104-
function wrapfun_iip(@nospecialize(ff))
105-
fwt = map(iip_arglists_default, iip_returnlists_default) do A, R
106-
FunctionWrappersWrappers.FunctionWrappers.FunctionWrapper{R, A}(Void(ff))
107-
end
108-
FunctionWrappersWrappers.FunctionWrappersWrapper{typeof(fwt), false}(fwt)
109-
end
11038

11139
promote_dual(::Type{T}, ::Type{T2}) where {T <: ForwardDiff.Dual, T2} = T
11240
function promote_dual(::Type{T},
@@ -497,9 +425,9 @@ unitfulvalue(x::Type{ForwardDiff.Dual{T, V, N}}) where {T, V, N} = V
497425
unitfulvalue(x::ForwardDiff.Dual) = unitfulvalue(ForwardDiff.value(x))
498426

499427
sse(x::ForwardDiff.Dual) = sse(ForwardDiff.value(x)) + sum(sse, ForwardDiff.partials(x))
500-
function DiffEqBase.totallength(x::ForwardDiff.Dual)
501-
return DiffEqBase.totallength(ForwardDiff.value(x)) +
502-
sum(DiffEqBase.totallength, ForwardDiff.partials(x))
428+
function SciMLBase.totallength(x::ForwardDiff.Dual)
429+
return SciMLBase.totallength(ForwardDiff.value(x)) +
430+
sum(SciMLBase.totallength, ForwardDiff.partials(x))
503431
end
504432

505433
end

ext/SciMLBaseReverseDiffExt.jl

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,57 @@
11
module SciMLBaseReverseDiffExt
22

3-
function DiffEqBase.anyeltypedual(::Type{T},
3+
using SciMLBase
4+
using ReverseDiff
5+
6+
function SciMLBase.anyeltypedual(::Type{T},
47
::Type{Val{counter}} = Val{0}) where {counter} where {
58
V, D, N, VA, DA, T <: ReverseDiff.TrackedArray{V, D, N, VA, DA}}
6-
DiffEqBase.anyeltypedual(V, Val{counter})
9+
SciMLBase.anyeltypedual(V, Val{counter})
710
end
811

9-
DiffEqBase.value(x::Type{ReverseDiff.TrackedReal{V, D, O}}) where {V, D, O} = V
10-
function DiffEqBase.value(x::Type{
12+
SciMLBase.value(x::Type{ReverseDiff.TrackedReal{V, D, O}}) where {V, D, O} = V
13+
function SciMLBase.value(x::Type{
1114
ReverseDiff.TrackedArray{V, D, N, VA, DA},
1215
}) where {V, D,
1316
N, VA,
1417
DA}
1518
Array{V, N}
1619
end
17-
DiffEqBase.value(x::ReverseDiff.TrackedReal) = x.value
18-
DiffEqBase.value(x::ReverseDiff.TrackedArray) = x.value
20+
SciMLBase.value(x::ReverseDiff.TrackedReal) = x.value
21+
SciMLBase.value(x::ReverseDiff.TrackedArray) = x.value
1922

20-
DiffEqBase.unitfulvalue(x::Type{ReverseDiff.TrackedReal{V, D, O}}) where {V, D, O} = V
21-
function DiffEqBase.unitfulvalue(x::Type{
23+
SciMLBase.unitfulvalue(x::Type{ReverseDiff.TrackedReal{V, D, O}}) where {V, D, O} = V
24+
function SciMLBase.unitfulvalue(x::Type{
2225
ReverseDiff.TrackedArray{V, D, N, VA, DA},
2326
}) where {V, D,
2427
N, VA,
2528
DA}
2629
Array{V, N}
2730
end
28-
DiffEqBase.unitfulvalue(x::ReverseDiff.TrackedReal) = x.value
29-
DiffEqBase.unitfulvalue(x::ReverseDiff.TrackedArray) = x.value
31+
SciMLBase.unitfulvalue(x::ReverseDiff.TrackedReal) = x.value
32+
SciMLBase.unitfulvalue(x::ReverseDiff.TrackedArray) = x.value
3033

3134
# Force TrackedArray from TrackedReal when reshaping W\b
32-
DiffEqBase._reshape(v::AbstractVector{<:ReverseDiff.TrackedReal}, siz) = reduce(vcat, v)
35+
SciMLBase._reshape(v::AbstractVector{<:ReverseDiff.TrackedReal}, siz) = reduce(vcat, v)
3336

34-
DiffEqBase.promote_u0(u0::ReverseDiff.TrackedArray, p::ReverseDiff.TrackedArray, t0) = u0
35-
function DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal},
37+
SciMLBase.promote_u0(u0::ReverseDiff.TrackedArray, p::ReverseDiff.TrackedArray, t0) = u0
38+
function SciMLBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal},
3639
p::ReverseDiff.TrackedArray, t0)
3740
u0
3841
end
39-
function DiffEqBase.promote_u0(u0::ReverseDiff.TrackedArray,
42+
function SciMLBase.promote_u0(u0::ReverseDiff.TrackedArray,
4043
p::AbstractArray{<:ReverseDiff.TrackedReal}, t0)
4144
u0
4245
end
43-
function DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal},
46+
function SciMLBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal},
4447
p::AbstractArray{<:ReverseDiff.TrackedReal}, t0)
4548
u0
4649
end
47-
DiffEqBase.promote_u0(u0, p::ReverseDiff.TrackedArray, t0) = ReverseDiff.track(u0)
48-
function DiffEqBase.promote_u0(
50+
SciMLBase.promote_u0(u0, p::ReverseDiff.TrackedArray, t0) = ReverseDiff.track(u0)
51+
function SciMLBase.promote_u0(
4952
u0, p::ReverseDiff.TrackedArray{T}, t0) where {T <: ReverseDiff.ForwardDiff.Dual}
5053
ReverseDiff.track(T.(u0))
5154
end
52-
DiffEqBase.promote_u0(u0, p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) = eltype(p).(u0)
55+
SciMLBase.promote_u0(u0, p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) = eltype(p).(u0)
5356

5457
end

0 commit comments

Comments
 (0)