Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,13 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
FastPower = "a4df4552-cc26-4903-aec0-212e50a0e84b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Expand All @@ -40,6 +38,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GeneralizedGenerated = "6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb"
GTPSA = "b27dd330-f138-47c5-815b-40db9dd9b6e8"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Expand All @@ -55,6 +54,7 @@ DiffEqBaseCUDAExt = "CUDA"
DiffEqBaseChainRulesCoreExt = "ChainRulesCore"
DiffEqBaseDistributionsExt = "Distributions"
DiffEqBaseEnzymeExt = ["ChainRulesCore", "Enzyme"]
DiffEqBaseForwardDiffExt = ["ForwardDiff"]
DiffEqBaseGeneralizedGeneratedExt = "GeneralizedGenerated"
DiffEqBaseGTPSAExt = "GTPSA"
DiffEqBaseMPIExt = "MPI"
Expand Down Expand Up @@ -92,7 +92,6 @@ Measurements = "2"
MonteCarloMeasurements = "1"
MuladdMacro = "0.2.1"
Parameters = "0.12.0"
PreallocationTools = "0.4"
PrecompileTools = "1"
Printf = "1.9"
RecursiveArrayTools = "3"
Expand Down
286 changes: 286 additions & 0 deletions ext/DiffEqBaseForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,286 @@
module DiffEqBaseForwardDiffExt

using DiffEqBase, ForwardDiff
using DiffEqBase.ArrayInterface
using DiffEqBase: Void, FunctionWrappersWrappers, OrdinaryDiffEqTag
import DiffEqBase: hasdualpromote, wrapfun_oop, wrapfun_iip, promote_u0, prob2dtmin,
promote_tspan, anyeltypedual, isdualtype, value, ODE_DEFAULT_NORM,
InternalITP,
nextfloat_tdir, promote_dual

eltypedual(x) = eltype(x) <: ForwardDiff.Dual
isdualtype(::Type{<:ForwardDiff.Dual}) = true
const dualT = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag, Float64}, Float64, 1}
dualgen(::Type{T}) where {T} = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag, T}, T, 1}

# Copy of the other prob2dtmin dispatch, just for optionality
function prob2dtmin(tspan, ::ForwardDiff.Dual, use_end_time)
t1, t2 = tspan
isfinite(t1) || throw(ArgumentError("t0 in the tspan `(t0, t1)` must be finite"))
if use_end_time && isfinite(t2 - t1)
return max(eps(t2), eps(t1))
else
return max(eps(typeof(t1)), eps(t1))
end
end

function hasdualpromote(u0, t::Number)
hasmethod(ArrayInterface.promote_eltype,
Tuple{Type{typeof(u0)}, Type{dualgen(eltype(u0))}}) &&
hasmethod(promote_rule,
Tuple{Type{eltype(u0)}, Type{dualgen(eltype(u0))}}) &&
hasmethod(promote_rule,
Tuple{Type{eltype(u0)}, Type{typeof(t)}})
end

const NORECOMPILE_IIP_SUPPORTED_ARGS = (
Tuple{Vector{Float64}, Vector{Float64},
Vector{Float64}, Float64},
Tuple{Vector{Float64}, Vector{Float64},
SciMLBase.NullParameters, Float64})

const oop_arglists = (Tuple{Vector{Float64}, Vector{Float64}, Float64},
Tuple{Vector{Float64}, SciMLBase.NullParameters, Float64},
Tuple{Vector{Float64}, Vector{Float64}, dualT},
Tuple{Vector{dualT}, Vector{Float64}, Float64},
Tuple{Vector{dualT}, SciMLBase.NullParameters, Float64},
Tuple{Vector{Float64}, SciMLBase.NullParameters, dualT})

const NORECOMPILE_OOP_SUPPORTED_ARGS = (Tuple{Vector{Float64},
Vector{Float64}, Float64},
Tuple{Vector{Float64},
SciMLBase.NullParameters, Float64})
const oop_returnlists = (Vector{Float64}, Vector{Float64},
ntuple(x -> Vector{dualT}, length(oop_arglists) - 2)...)

function wrapfun_oop(ff, inputs::Tuple = ())
if !isempty(inputs)
IT = Tuple{map(typeof, inputs)...}
if IT ∉ NORECOMPILE_OOP_SUPPORTED_ARGS
throw(NoRecompileArgumentError(IT))
end
end
FunctionWrappersWrappers.FunctionWrappersWrapper(ff, oop_arglists,
oop_returnlists)
end

function wrapfun_iip(ff,
inputs::Tuple{T1, T2, T3, T4}) where {T1, T2, T3, T4}
T = eltype(T2)
dualT = dualgen(T)
dualT1 = ArrayInterface.promote_eltype(T1, dualT)
dualT2 = ArrayInterface.promote_eltype(T2, dualT)
dualT4 = dualgen(promote_type(T, T4))

iip_arglists = (Tuple{T1, T2, T3, T4},
Tuple{dualT1, dualT2, T3, T4},
Tuple{dualT1, T2, T3, dualT4},
Tuple{dualT1, dualT2, T3, dualT4})

iip_returnlists = ntuple(x -> Nothing, 4)

fwt = map(iip_arglists, iip_returnlists) do A, R
FunctionWrappersWrappers.FunctionWrappers.FunctionWrapper{R, A}(Void(ff))
end
FunctionWrappersWrappers.FunctionWrappersWrapper{typeof(fwt), false}(fwt)
end

const iip_arglists_default = (
Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64},
Float64},
Tuple{Vector{Float64}, Vector{Float64},
SciMLBase.NullParameters,
Float64
},
Tuple{Vector{dualT}, Vector{Float64}, Vector{Float64}, dualT},
Tuple{Vector{dualT}, Vector{dualT}, Vector{Float64}, dualT},
Tuple{Vector{dualT}, Vector{dualT}, Vector{Float64}, Float64},
Tuple{Vector{dualT}, Vector{dualT}, SciMLBase.NullParameters,
Float64
},
Tuple{Vector{dualT}, Vector{Float64},
SciMLBase.NullParameters, dualT
})
const iip_returnlists_default = ntuple(x -> Nothing, length(iip_arglists_default))

function wrapfun_iip(@nospecialize(ff))
fwt = map(iip_arglists_default, iip_returnlists_default) do A, R
FunctionWrappersWrappers.FunctionWrappers.FunctionWrapper{R, A}(Void(ff))
end
FunctionWrappersWrappers.FunctionWrappersWrapper{typeof(fwt), false}(fwt)
end

promote_dual(::Type{T}, ::Type{T2}) where {T <: ForwardDiff.Dual, T2} = T
function promote_dual(::Type{T},
::Type{T2}) where {T <: ForwardDiff.Dual, T2 <: ForwardDiff.Dual}
T
end
promote_dual(::Type{T}, ::Type{T2}) where {T, T2 <: ForwardDiff.Dual} = T2

function promote_dual(::Type{T},
::Type{T2}) where {T3, T4, V, V2 <: ForwardDiff.Dual, N, N2,
T <: ForwardDiff.Dual{T3, V, N},
T2 <: ForwardDiff.Dual{T4, V2, N2}}
T2
end
function promote_dual(::Type{T},
::Type{T2}) where {T3, T4, V <: ForwardDiff.Dual, V2, N, N2,
T <: ForwardDiff.Dual{T3, V, N},
T2 <: ForwardDiff.Dual{T4, V2, N2}}
T
end
function promote_dual(::Type{T},
::Type{T2}) where {
T3, V <: ForwardDiff.Dual, V2 <: ForwardDiff.Dual,
N,
T <: ForwardDiff.Dual{T3, V, N},
T2 <: ForwardDiff.Dual{T3, V2, N}}
ForwardDiff.Dual{T3, promote_dual(V, V2), N}
end

# Opt out since these are using for preallocation, not differentiation
function anyeltypedual(x::Union{ForwardDiff.AbstractConfig, Module},
::Type{Val{counter}} = Val{0}) where {counter}
Any
end
function anyeltypedual(x::Type{T},
::Type{Val{counter}} = Val{0}) where {counter} where {T <:
ForwardDiff.AbstractConfig}
Any
end

function anyeltypedual(x::ForwardDiff.DiffResults.DiffResult,
::Type{Val{counter}} = Val{0}) where {counter}
Any
end
function anyeltypedual(x::Type{T},
::Type{Val{counter}} = Val{0}) where {counter} where {T <:
ForwardDiff.DiffResults.DiffResult}
Any
end

function anyeltypedual(::Type{T},
::Type{Val{counter}} = Val{0}) where {counter} where {T <: ForwardDiff.Dual}
T
end

function promote_tspan(u0::AbstractArray{<:ForwardDiff.Dual}, p, tspan, prob, kwargs)
if (haskey(kwargs, :callback) && has_continuous_callback(kwargs[:callback])) ||
(haskey(prob.kwargs, :callback) && has_continuous_callback(prob.kwargs[:callback]))
return _promote_tspan(eltype(u0).(tspan), kwargs)
else
return _promote_tspan(tspan, kwargs)
end
end

function promote_tspan(u0::AbstractArray{<:Complex{<:ForwardDiff.Dual}}, p, tspan, prob,
kwargs)
return _promote_tspan(real(eltype(u0)).(tspan), kwargs)
end

function promote_tspan(u0::AbstractArray{<:ForwardDiff.Dual}, p,
tspan::Tuple{<:ForwardDiff.Dual, <:ForwardDiff.Dual}, prob, kwargs)
return _promote_tspan(tspan, kwargs)
end

value(x::Type{ForwardDiff.Dual{T, V, N}}) where {T, V, N} = V
value(x::ForwardDiff.Dual) = value(ForwardDiff.value(x))

sse(x::Number) = abs2(x)
sse(x::ForwardDiff.Dual) = sse(ForwardDiff.value(x)) + sum(sse, ForwardDiff.partials(x))
totallength(x::Number) = 1
function totallength(x::ForwardDiff.Dual)
totallength(ForwardDiff.value(x)) + sum(totallength, ForwardDiff.partials(x))
end
totallength(x::AbstractArray) = __sum(totallength, x; init = 0)

@inline ODE_DEFAULT_NORM(u::ForwardDiff.Dual, ::Any) = sqrt(sse(u))
@inline function ODE_DEFAULT_NORM(u::AbstractArray{<:ForwardDiff.Dual{Tag, T}},
t::Any) where {Tag, T}
sqrt(__sum(sse, u; init = sse(zero(T))) / totallength(u))
end
@inline ODE_DEFAULT_NORM(u::ForwardDiff.Dual, ::ForwardDiff.Dual) = sqrt(sse(u))
@inline function ODE_DEFAULT_NORM(u::AbstractArray{<:ForwardDiff.Dual{Tag, T}},
::ForwardDiff.Dual) where {Tag, T}
sqrt(__sum(sse, u; init = sse(zero(T))) / totallength(u))
end

if !hasmethod(nextfloat, Tuple{ForwardDiff.Dual})
# Type piracy. Should upstream
function Base.nextfloat(d::ForwardDiff.Dual{T, V, N}) where {T, V, N}
ForwardDiff.Dual{T}(nextfloat(d.value), d.partials)
end
function Base.prevfloat(d::ForwardDiff.Dual{T, V, N}) where {T, V, N}
ForwardDiff.Dual{T}(prevfloat(d.value), d.partials)
end
end

# Static Arrays don't support the `init` keyword argument for `sum`
@inline __sum(f::F, args...; init, kwargs...) where {F} = sum(f, args...; init, kwargs...)
@inline function __sum(
f::F, a::DiffEqBase.StaticArraysCore.StaticArray...; init, kwargs...) where {F}
return mapreduce(f, +, a...; init, kwargs...)
end

# Differentiation of internal solver

function scalar_nlsolve_ad(prob, alg::InternalITP, args...; kwargs...)
f = prob.f
p = value(prob.p)

if prob isa IntervalNonlinearProblem
tspan = value(prob.tspan)
newprob = IntervalNonlinearProblem(f, tspan, p; prob.kwargs...)
else
u0 = value(prob.u0)
newprob = NonlinearProblem(f, u0, p; prob.kwargs...)
end

sol = solve(newprob, alg, args...; kwargs...)

uu = sol.u
if p isa Number
f_p = ForwardDiff.derivative(Base.Fix1(f, uu), p)
else
f_p = ForwardDiff.gradient(Base.Fix1(f, uu), p)
end

f_x = ForwardDiff.derivative(Base.Fix2(f, p), uu)
pp = prob.p
sumfun = let f_x′ = -f_x
((fp, p),) -> (fp / f_x′) * ForwardDiff.partials(p)
end
partials = sum(sumfun, zip(f_p, pp))
return sol, partials
end

function SciMLBase.solve(
prob::IntervalNonlinearProblem{uType, iip,
<:ForwardDiff.Dual{T, V, P}},
alg::InternalITP, args...;
kwargs...) where {uType, iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return SciMLBase.build_solution(prob, alg, ForwardDiff.Dual{T, V, P}(sol.u, partials),
sol.resid; retcode = sol.retcode,
left = ForwardDiff.Dual{T, V, P}(sol.left, partials),
right = ForwardDiff.Dual{T, V, P}(sol.right, partials))
end

function SciMLBase.solve(
prob::IntervalNonlinearProblem{uType, iip,
<:AbstractArray{
<:ForwardDiff.Dual{T,
V,
P},
}},
alg::InternalITP, args...;
kwargs...) where {uType, iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)

return SciMLBase.build_solution(prob, alg, ForwardDiff.Dual{T, V, P}(sol.u, partials),
sol.resid; retcode = sol.retcode,
left = ForwardDiff.Dual{T, V, P}(sol.left, partials),
right = ForwardDiff.Dual{T, V, P}(sol.right, partials))
end

end
2 changes: 1 addition & 1 deletion ext/DiffEqBaseReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ function DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal},
end
DiffEqBase.promote_u0(u0, p::ReverseDiff.TrackedArray, t0) = ReverseDiff.track(u0)
function DiffEqBase.promote_u0(
u0, p::ReverseDiff.TrackedArray{T}, t0) where {T <: ForwardDiff.Dual}
u0, p::ReverseDiff.TrackedArray{T}, t0) where {T <: ReverseDiff.ForwardDiff.Dual}
ReverseDiff.track(T.(u0))
end
DiffEqBase.promote_u0(u0, p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) = eltype(p).(u0)
Expand Down
17 changes: 9 additions & 8 deletions src/DiffEqBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ end
import PrecompileTools

import FastPower
@deprecate fastpow(x,y) FastPower.fastpower(x,y)
@deprecate fastpow(x, y) FastPower.fastpower(x, y)

using ArrayInterface

Expand All @@ -32,18 +32,13 @@ import TruncatedStacktraces

using Setfield

using ForwardDiff

using EnumX

using Markdown

using ConcreteStructs: @concrete
using FastClosures: @closure

# Could be made optional/glue
import PreallocationTools

import FunctionWrappersWrappers

using SciMLBase
Expand Down Expand Up @@ -111,6 +106,14 @@ Reexport.@reexport using SciMLBase

SciMLBase.isfunctionwrapper(x::FunctionWrapper) = true

## Extension Functions

eltypedual(x) = false
promote_u0(::Nothing, p, t0) = nothing
isdualtype(::Type{T}) where {T} = true

## Types

"""
$(TYPEDEF)
"""
Expand All @@ -132,14 +135,12 @@ include("utils.jl")
include("stats.jl")
include("calculate_residuals.jl")
include("tableaus.jl")
include("internal_falsi.jl")
include("internal_itp.jl")

include("callbacks.jl")
include("common_defaults.jl")
include("solve.jl")
include("internal_euler.jl")
include("forwarddiff.jl")
include("termination_conditions_deprecated.jl") # TODO: remove in the next major release
include("termination_conditions.jl")
include("norecompile.jl")
Expand Down
Loading
Loading