Skip to content

Commit 05caad6

Browse files
almost
1 parent 3297e59 commit 05caad6

File tree

6 files changed

+133
-108
lines changed

6 files changed

+133
-108
lines changed

ext/DiffEqBaseForwardDiffExt.jl

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
module DiffEqBaseForwardDiffExt
22

33
using DiffEqBase, ForwardDiff
4+
using DiffEqBase.ArrayInterface
45
using DiffEqBase: Void, FunctionWrappersWrappers, OrdinaryDiffEqTag
56
import DiffEqBase: hasdualpromote, wrapfun_oop, wrapfun_iip, promote_u0, prob2dtmin,
6-
promote_tspan, anyeltypedual, isdualtype, value, ODE_DEFAULT_NORM, InternalITP,
7-
nextfloat_tdir
7+
promote_tspan, anyeltypedual, isdualtype, value, ODE_DEFAULT_NORM,
8+
InternalITP,
9+
nextfloat_tdir, promote_dual
810

9-
const DUALCHECK_RECURSION_MAX = 10
10-
11-
eltypedual(x) = eltype(x) <: ForwardDiff.Dual
12-
isdualtype(::Type{<:ForwardDiff.Dual}) = true
11+
eltypedual(x) = eltype(x) <: ForwardDiff.Dual
12+
isdualtype(::Type{<:ForwardDiff.Dual}) = true
1313
const dualT = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag, Float64}, Float64, 1}
1414
dualgen(::Type{T}) where {T} = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag, T}, T, 1}
1515

@@ -24,12 +24,14 @@ function prob2dtmin(tspan, ::ForwardDiff.Dual, use_end_time)
2424
end
2525
end
2626

27-
hasdualpromote(u0,t::Number) = hasmethod(ArrayInterface.promote_eltype,
28-
Tuple{Type{typeof(u0)}, Type{dualgen(eltype(u0))}}) &&
29-
hasmethod(promote_rule,
30-
Tuple{Type{eltype(u0)}, Type{dualgen(eltype(u0))}}) &&
31-
hasmethod(promote_rule,
32-
Tuple{Type{eltype(u0)}, Type{typeof(t)}})
27+
function hasdualpromote(u0, t::Number)
28+
hasmethod(ArrayInterface.promote_eltype,
29+
Tuple{Type{typeof(u0)}, Type{dualgen(eltype(u0))}}) &&
30+
hasmethod(promote_rule,
31+
Tuple{Type{eltype(u0)}, Type{dualgen(eltype(u0))}}) &&
32+
hasmethod(promote_rule,
33+
Tuple{Type{eltype(u0)}, Type{typeof(t)}})
34+
end
3335

3436
const NORECOMPILE_IIP_SUPPORTED_ARGS = (
3537
Tuple{Vector{Float64}, Vector{Float64},
@@ -108,16 +110,6 @@ function wrapfun_iip(@nospecialize(ff))
108110
FunctionWrappersWrappers.FunctionWrappersWrapper{typeof(fwt), false}(fwt)
109111
end
110112

111-
"""
112-
promote_dual(::Type{T},::Type{T2})
113-
114-
115-
Is like the number promotion system, but always prefers a dual number type above
116-
anything else. For higher order differentiation, it returns the most dualiest of
117-
them all. This is then used to promote `u0` into the suspected highest differentiation
118-
space for solving the equation.
119-
"""
120-
promote_dual(::Type{T}, ::Type{T2}) where {T, T2} = T
121113
promote_dual(::Type{T}, ::Type{T2}) where {T <: ForwardDiff.Dual, T2} = T
122114
function promote_dual(::Type{T},
123115
::Type{T2}) where {T <: ForwardDiff.Dual, T2 <: ForwardDiff.Dual}
@@ -227,7 +219,8 @@ end
227219

228220
# Static Arrays don't support the `init` keyword argument for `sum`
229221
@inline __sum(f::F, args...; init, kwargs...) where {F} = sum(f, args...; init, kwargs...)
230-
@inline function __sum(f::F, a::DiffEqBase.StaticArraysCore.StaticArray...; init, kwargs...) where {F}
222+
@inline function __sum(
223+
f::F, a::DiffEqBase.StaticArraysCore.StaticArray...; init, kwargs...) where {F}
231224
return mapreduce(f, +, a...; init, kwargs...)
232225
end
233226

@@ -292,4 +285,4 @@ function SciMLBase.solve(
292285
right = ForwardDiff.Dual{T, V, P}(sol.right, partials))
293286
end
294287

295-
end
288+
end

src/DiffEqBase.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ end
77
import PrecompileTools
88

99
import FastPower
10-
@deprecate fastpow(x,y) FastPower.fastpower(x,y)
10+
@deprecate fastpow(x, y) FastPower.fastpower(x, y)
1111

1212
using ArrayInterface
1313

@@ -110,7 +110,7 @@ SciMLBase.isfunctionwrapper(x::FunctionWrapper) = true
110110

111111
eltypedual(x) = false
112112
promote_u0(::Nothing, p, t0) = nothing
113-
isdualtype(::Type{T}) where T = true
113+
isdualtype(::Type{T}) where {T} = true
114114

115115
## Types
116116

src/internal_itp.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,20 @@
1+
"""
2+
prevfloat_tdir(x, x0, x1)
3+
4+
Move `x` one floating point towards x0.
5+
"""
6+
function prevfloat_tdir(x, x0, x1)
7+
x1 > x0 ? prevfloat(x) : nextfloat(x)
8+
end
9+
10+
function nextfloat_tdir(x, x0, x1)
11+
x1 > x0 ? nextfloat(x) : prevfloat(x)
12+
end
13+
14+
function max_tdir(a, b, x0, x1)
15+
x1 > x0 ? max(a, b) : min(a, b)
16+
end
17+
118
"""
219
`InternalITP`: A non-allocating ITP method, internal to DiffEqBase for
320
simpler dependencies.

src/norecompile.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,4 @@ end
2424
# Default dispatch assumes no ForwardDiff, gets added in the new dispatch
2525
function wrapfun_iip(ff, inputs)
2626
FunctionWrappersWrappers.FunctionWrappers.FunctionWrapper{Nothing, typeof(inputs)}(Void(ff))
27-
end
27+
end

src/solve.jl

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,8 @@ function init_up(prob::AbstractDEProblem, sensealg, u0, p, args...; kwargs...)
574574
if tstops === nothing && has_kwargs(prob)
575575
tstops = get(prob.kwargs, :tstops, nothing)
576576
end
577-
if !(tstops isa Union{Nothing, AbstractArray, Tuple, Real}) && !SciMLBase.allows_late_binding_tstops(alg)
577+
if !(tstops isa Union{Nothing, AbstractArray, Tuple, Real}) &&
578+
!SciMLBase.allows_late_binding_tstops(alg)
578579
throw(LateBindingTstopsNotSupportedError())
579580
end
580581
_prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...)
@@ -1110,7 +1111,8 @@ function solve_up(prob::Union{AbstractDEProblem, NonlinearProblem}, sensealg, u0
11101111
if tstops === nothing && has_kwargs(prob)
11111112
tstops = get(prob.kwargs, :tstops, nothing)
11121113
end
1113-
if !(tstops isa Union{Nothing, AbstractArray, Tuple, Real}) && !SciMLBase.allows_late_binding_tstops(alg)
1114+
if !(tstops isa Union{Nothing, AbstractArray, Tuple, Real}) &&
1115+
!SciMLBase.allows_late_binding_tstops(alg)
11141116
throw(LateBindingTstopsNotSupportedError())
11151117
end
11161118
_prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...)
@@ -1283,23 +1285,23 @@ function promote_f(f::F, ::Val{specialize}, u0, p, t) where {F, specialize}
12831285
end
12841286

12851287
f = if f isa ODEFunction && isinplace(f) && !(f.f isa AbstractSciMLOperator) &&
1286-
# Some reinitialization code still uses NLSolvers stuff which doesn't
1287-
# properly tag, so opt-out if potentially a mass matrix DAE
1288-
f.mass_matrix isa UniformScaling &&
1289-
# Jacobians don't wrap, so just ignore those cases
1290-
f.jac === nothing &&
1291-
((specialize === SciMLBase.AutoSpecialize && eltype(u0) !== Any &&
1292-
RecursiveArrayTools.recursive_unitless_eltype(u0) === eltype(u0) &&
1293-
one(t) === oneunit(t) && hasdualpromote(u0,t)) ||
1288+
# Some reinitialization code still uses NLSolvers stuff which doesn't
1289+
# properly tag, so opt-out if potentially a mass matrix DAE
1290+
f.mass_matrix isa UniformScaling &&
1291+
# Jacobians don't wrap, so just ignore those cases
1292+
f.jac === nothing &&
1293+
((specialize === SciMLBase.AutoSpecialize && eltype(u0) !== Any &&
1294+
RecursiveArrayTools.recursive_unitless_eltype(u0) === eltype(u0) &&
1295+
one(t) === oneunit(t) && hasdualpromote(u0, t)) ||
12941296
(specialize === SciMLBase.FunctionWrapperSpecialize &&
1295-
!(f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper)))
1297+
!(f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper)))
12961298
return unwrapped_f(f, wrapfun_iip(f.f, (u0, u0, p, t)))
12971299
else
12981300
return f
12991301
end
13001302
end
13011303

1302-
hasdualpromote(u0,t) = true
1304+
hasdualpromote(u0, t) = true
13031305

13041306
function promote_f(f::SplitFunction, ::Val{specialize}, u0, p, t) where {specialize}
13051307
typeof(f.cache) === typeof(u0) && isinplace(f) ? f : remake(f, cache = zero(u0))

0 commit comments

Comments
 (0)