11module DiffEqBaseForwardDiffExt
22
33using DiffEqBase, ForwardDiff
4+ using DiffEqBase. ArrayInterface
45using DiffEqBase: Void, FunctionWrappersWrappers, OrdinaryDiffEqTag
56import DiffEqBase: hasdualpromote, wrapfun_oop, wrapfun_iip, promote_u0, prob2dtmin,
6- promote_tspan, anyeltypedual, isdualtype, value, ODE_DEFAULT_NORM, InternalITP,
7- nextfloat_tdir
7+ promote_tspan, anyeltypedual, isdualtype, value, ODE_DEFAULT_NORM,
8+ InternalITP,
9+ nextfloat_tdir, promote_dual
810
9- const DUALCHECK_RECURSION_MAX = 10
10-
11- eltypedual (x) = eltype (x) <: ForwardDiff.Dual
12- isdualtype (:: Type{<:ForwardDiff.Dual} ) = true
11+ eltypedual (x) = eltype (x) <: ForwardDiff.Dual
12+ isdualtype (:: Type{<:ForwardDiff.Dual} ) = true
1313const dualT = ForwardDiff. Dual{ForwardDiff. Tag{OrdinaryDiffEqTag, Float64}, Float64, 1 }
1414dualgen (:: Type{T} ) where {T} = ForwardDiff. Dual{ForwardDiff. Tag{OrdinaryDiffEqTag, T}, T, 1 }
1515
@@ -24,12 +24,14 @@ function prob2dtmin(tspan, ::ForwardDiff.Dual, use_end_time)
2424 end
2525end
2626
27- hasdualpromote (u0,t:: Number ) = hasmethod (ArrayInterface. promote_eltype,
28- Tuple{Type{typeof (u0)}, Type{dualgen (eltype (u0))}}) &&
29- hasmethod (promote_rule,
30- Tuple{Type{eltype (u0)}, Type{dualgen (eltype (u0))}}) &&
31- hasmethod (promote_rule,
32- Tuple{Type{eltype (u0)}, Type{typeof (t)}})
27+ function hasdualpromote (u0, t:: Number )
28+ hasmethod (ArrayInterface. promote_eltype,
29+ Tuple{Type{typeof (u0)}, Type{dualgen (eltype (u0))}}) &&
30+ hasmethod (promote_rule,
31+ Tuple{Type{eltype (u0)}, Type{dualgen (eltype (u0))}}) &&
32+ hasmethod (promote_rule,
33+ Tuple{Type{eltype (u0)}, Type{typeof (t)}})
34+ end
3335
3436const NORECOMPILE_IIP_SUPPORTED_ARGS = (
3537 Tuple{Vector{Float64}, Vector{Float64},
@@ -108,16 +110,6 @@ function wrapfun_iip(@nospecialize(ff))
108110 FunctionWrappersWrappers. FunctionWrappersWrapper {typeof(fwt), false} (fwt)
109111end
110112
111- """
112- promote_dual(::Type{T},::Type{T2})
113-
114-
115- Is like the number promotion system, but always prefers a dual number type above
116- anything else. For higher order differentiation, it returns the most dualiest of
117- them all. This is then used to promote `u0` into the suspected highest differentiation
118- space for solving the equation.
119- """
120- promote_dual (:: Type{T} , :: Type{T2} ) where {T, T2} = T
121113promote_dual (:: Type{T} , :: Type{T2} ) where {T <: ForwardDiff.Dual , T2} = T
122114function promote_dual (:: Type{T} ,
123115 :: Type{T2} ) where {T <: ForwardDiff.Dual , T2 <: ForwardDiff.Dual }
227219
228220# Static Arrays don't support the `init` keyword argument for `sum`
229221@inline __sum (f:: F , args... ; init, kwargs... ) where {F} = sum (f, args... ; init, kwargs... )
230- @inline function __sum (f:: F , a:: DiffEqBase.StaticArraysCore.StaticArray... ; init, kwargs... ) where {F}
222+ @inline function __sum (
223+ f:: F , a:: DiffEqBase.StaticArraysCore.StaticArray... ; init, kwargs... ) where {F}
231224 return mapreduce (f, + , a... ; init, kwargs... )
232225end
233226
@@ -292,4 +285,4 @@ function SciMLBase.solve(
292285 right = ForwardDiff. Dual {T, V, P} (sol. right, partials))
293286end
294287
295- end
288+ end
0 commit comments