Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 58 additions & 1 deletion ext/DiffEqBaseForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,71 @@ module DiffEqBaseForwardDiffExt
using DiffEqBase, ForwardDiff
using DiffEqBase.ArrayInterface
using DiffEqBase: Void, FunctionWrappersWrappers, OrdinaryDiffEqTag,
DEIIPFunctionWrapperForwardDiff,
AbstractTimeseriesSolution,
RecursiveArrayTools, reduce_tup, _promote_tspan, has_continuous_callback
using FunctionWrappers: FunctionWrapper
import DiffEqBase: hasdualpromote, wrapfun_oop, wrapfun_iip, prob2dtmin,
promote_tspan, ODE_DEFAULT_NORM
import SciMLBase: isdualtype, DualEltypeChecker, sse, __sum

const dualT = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag, Float64}, Float64, 1}
dualgen(::Type{T}) where {T} = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag, T}, T, 1}

"""
ODEDualTag

Type alias for the ForwardDiff tag used by ODE solvers:
`ForwardDiff.Tag{OrdinaryDiffEqTag, Float64}`.

Access from downstream packages via:
```julia
ext = Base.get_extension(DiffEqBase, :DiffEqBaseForwardDiffExt)
ext.ODEDualTag
```
"""
const ODEDualTag = ForwardDiff.Tag{OrdinaryDiffEqTag, Float64}

"""
ODEDualType

Type alias for the ForwardDiff dual number used by ODE solvers:
`ForwardDiff.Dual{ODEDualTag, Float64, 1}`.

Access from downstream packages via:
```julia
ext = Base.get_extension(DiffEqBase, :DiffEqBaseForwardDiffExt)
ext.ODEDualType
```
"""
const ODEDualType = ForwardDiff.Dual{ODEDualTag, Float64, 1}

"""
DEIIPFunctionWrapperForwardDiffVF64{pType}

VF64-specialized alias for `DEIIPFunctionWrapperForwardDiff` matching the common
in-place `Vector{Float64}` ODE case with ForwardDiff support.

Equivalent to:
```julia
DEIIPFunctionWrapperForwardDiff{
Vector{Float64}, Vector{Float64}, pType, Float64,
Vector{ODEDualType}, Vector{ODEDualType}, ODEDualType,
}
```

Access from downstream packages via:
```julia
ext = Base.get_extension(DiffEqBase, :DiffEqBaseForwardDiffExt)
ext.DEIIPFunctionWrapperForwardDiffVF64
```
"""
const DEIIPFunctionWrapperForwardDiffVF64{pType} =
DEIIPFunctionWrapperForwardDiff{
Vector{Float64}, Vector{Float64}, pType, Float64,
Vector{dualT}, Vector{dualT}, dualT,
}

const NORECOMPILE_IIP_SUPPORTED_ARGS = (
Tuple{
Vector{Float64}, Vector{Float64},
Expand Down Expand Up @@ -82,7 +138,8 @@ function wrapfun_iip(
fwt = map(iip_arglists, iip_returnlists) do A, R
FunctionWrappersWrappers.FunctionWrappers.FunctionWrapper{R, A}(Void(ff))
end
return FunctionWrappersWrappers.FunctionWrappersWrapper{typeof(fwt), false}(fwt)
inner = FunctionWrappersWrappers.FunctionWrappersWrapper{typeof(fwt), false}(fwt)
return DEIIPFunctionWrapperForwardDiff(inner)
end

const iip_arglists_default = (
Expand Down
7 changes: 7 additions & 0 deletions src/DiffEqBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ include("common_defaults.jl")
include("solve.jl")
include("internal_euler.jl")
include("norecompile.jl")
# unwrapped_f support for DE wrapper structs (delegates to inner FunctionWrappersWrapper)
unwrapped_f(f::DEIIPFunctionWrapper) = unwrapped_f(f.fw)
unwrapped_f(f::DEIIPFunctionWrapperForwardDiff) = unwrapped_f(f.fw)
include("integrator_accessors.jl")

# This is only used for oop stiff solvers
Expand All @@ -175,6 +178,10 @@ export initialize!, finalize!

export SensitivityADPassThrough

# FunctionWrapper structs and aliases for the VF64 pattern
export DEIIPFunctionWrapper, DEIIPFunctionWrapperVF64,
DEIIPFunctionWrapperForwardDiff, AnyFunctionWrapper, wrapfun_iip_simple

include("precompilation.jl")

end # module
92 changes: 92 additions & 0 deletions src/norecompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,98 @@ function unwrap_fw(fw::FunctionWrapper)
return fw.obj[]
end

"""
DEIIPFunctionWrapper{duType, uType, pType, tType}

Wrapper struct around a `FunctionWrappersWrapper` containing a single `FunctionWrapper`
for an in-place function `f!(du, u, p, t) -> Nothing`.

Compared to a raw `FunctionWrappersWrapper`, this struct exposes only 4 type
parameters (du, u, p, t types) instead of the full nested
`FunctionWrappersWrapper{Tuple{FunctionWrapper{...}}, false}` type, which
significantly reduces type string length in stack traces.

Used by solvers that do **not** require ForwardDiff internally (e.g. Tsit5, Verner).
See also [`DEIIPFunctionWrapperForwardDiff`](@ref) for the ForwardDiff-aware variant.
"""
struct DEIIPFunctionWrapper{duType, uType, pType, tType}
fw::FunctionWrappersWrappers.FunctionWrappersWrapper{
Tuple{FunctionWrapper{Nothing, Tuple{duType, uType, pType, tType}}},
false,
}
end

(f::DEIIPFunctionWrapper)(args...) = f.fw(args...)
SciMLBase.isfunctionwrapper(::DEIIPFunctionWrapper) = true

"""
DEIIPFunctionWrapperVF64{pType}

VF64-specialized alias: `DEIIPFunctionWrapper{Vector{Float64}, Vector{Float64}, pType, Float64}`.
Matches the wrapper produced for the common in-place `Vector{Float64}` ODE case
when ForwardDiff is **not** used by the solver.
"""
const DEIIPFunctionWrapperVF64{pType} =
DEIIPFunctionWrapper{Vector{Float64}, Vector{Float64}, pType, Float64}

"""
DEIIPFunctionWrapperForwardDiff{T1, T2, T3, T4, dT1, dT2, dT4}

Wrapper struct around a `FunctionWrappersWrapper` containing 4 `FunctionWrapper`
entries for an in-place function `f!(du, u, p, t) -> Nothing` with ForwardDiff support.

The 4 wrappers cover:
1. Base types: `(T1, T2, T3, T4)`
2. Dual state: `(dT1, dT2, T3, T4)`
3. Dual time: `(dT1, T2, T3, dT4)`
4. Dual state+time: `(dT1, dT2, T3, dT4)`

Compared to a raw `FunctionWrappersWrapper`, this struct exposes 7 type parameters
instead of repeating the full `FunctionWrapper{Nothing, Tuple{...}}` 4 times,
significantly reducing type string length in stack traces.

Used by solvers that require ForwardDiff internally (e.g. Rosenbrock, implicit methods).
See also [`DEIIPFunctionWrapper`](@ref) for the simpler non-ForwardDiff variant.
"""
struct DEIIPFunctionWrapperForwardDiff{T1, T2, T3, T4, dT1, dT2, dT4}
fw::FunctionWrappersWrappers.FunctionWrappersWrapper{
Tuple{
FunctionWrapper{Nothing, Tuple{T1, T2, T3, T4}},
FunctionWrapper{Nothing, Tuple{dT1, dT2, T3, T4}},
FunctionWrapper{Nothing, Tuple{dT1, T2, T3, dT4}},
FunctionWrapper{Nothing, Tuple{dT1, dT2, T3, dT4}},
},
false,
}
end

(f::DEIIPFunctionWrapperForwardDiff)(args...) = f.fw(args...)
SciMLBase.isfunctionwrapper(::DEIIPFunctionWrapperForwardDiff) = true

# Union for isa checks (avoid double-wrapping)
const AnyFunctionWrapper = Union{
FunctionWrappersWrappers.FunctionWrappersWrapper,
DEIIPFunctionWrapper,
DEIIPFunctionWrapperForwardDiff,
}

"""
wrapfun_iip_simple(ff, du, u, p, t)

Wrap an in-place function `ff(du, u, p, t) -> Nothing` into a [`DEIIPFunctionWrapper`](@ref)
(single `FunctionWrapper`, no ForwardDiff support).

Unlike [`wrapfun_iip`](@ref), this function is **not** overridden by the ForwardDiff
extension, so it avoids creating method-table backedges that cause invalidation.
Use this when the solver does not need ForwardDiff internally.
"""
function wrapfun_iip_simple(ff, du, u, p, t)
inner = FunctionWrappersWrappers.FunctionWrappersWrapper(
Void(ff), (typeof((du, u, p, t)),), (Nothing,)
)
return DEIIPFunctionWrapper(inner)
end

# Default dispatch assumes no ForwardDiff, gets added in the new dispatch
function wrapfun_iip(ff, inputs)
return FunctionWrappersWrappers.FunctionWrappersWrapper(
Expand Down
17 changes: 6 additions & 11 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -790,19 +790,19 @@ function promote_f(f::F, ::Val{specialize}, u0, p, t, ::Val{true}) where {F, spe
) ||
(
specialize === SciMLBase.FunctionWrapperSpecialize &&
!(f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper)
!(f.f isa AnyFunctionWrapper)
)
)
# Wrap tgrad if present, so its type is also erased.
# tgrad!(dT, u, p, t) -> Nothing has the same shape as the RHS.
if f.tgrad !== nothing && !(f.tgrad isa FunctionWrappersWrappers.FunctionWrappersWrapper)
if f.tgrad !== nothing && !(f.tgrad isa AnyFunctionWrapper)
f = @set f.tgrad = wrapfun_jac_iip(f.tgrad, (u0, u0, p, t))
end
# Wrap the Jacobian if present, so its type is also erased
if f.jac !== nothing && !(f.jac isa FunctionWrappersWrappers.FunctionWrappersWrapper)
if f.jac !== nothing && !(f.jac isa AnyFunctionWrapper)
n = length(u0)
J_proto = f.jac_prototype !== nothing ? similar(f.jac_prototype, uElType) :
zeros(uElType, n, n)
zeros(uElType, n, n)
f = @set f.jac = wrapfun_jac_iip(f.jac, (J_proto, u0, p, t))
end
return unwrapped_f(f, wrapfun_iip(f.f, (u0, u0, p, t)))
Expand Down Expand Up @@ -833,15 +833,10 @@ function promote_f(f::F, ::Val{specialize}, u0, p, t, ::Val{false}) where {F, sp
) ||
(
specialize === SciMLBase.FunctionWrapperSpecialize &&
!(f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper)
)
)
return unwrapped_f(
f,
FunctionWrappersWrappers.FunctionWrappersWrapper(
Void(f.f), (typeof((u0, u0, p, t)),), (Nothing,)
!(f.f isa AnyFunctionWrapper)
)
)
return unwrapped_f(f, wrapfun_iip_simple(f.f, u0, u0, p, t))
else
return f
end
Expand Down
8 changes: 5 additions & 3 deletions test/downstream/callback_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,16 @@ end

function loss_zyg(p)
prob = ODEProblem(f_zyg, [1.0, 0.0], (0.0, 1.0), p)
sol = solve(prob, Tsit5(), callback = cb_zyg,
abstol = 1e-14, reltol = 1e-14, save_everystep = false)
sol = solve(
prob, Tsit5(), callback = cb_zyg,
abstol = 1.0e-14, reltol = 1.0e-14, save_everystep = false
)
return sum(sol.u[end])
end

p = [9.8, 0.8]
grad = Zygote.gradient(loss_zyg, p)[1]
findiff_grad = FiniteDiff.finite_difference_gradient(loss_zyg, p)
@test all(isfinite, grad)
@test grad ≈ findiff_grad rtol = 1e-3
@test grad ≈ findiff_grad rtol = 1.0e-3
end
132 changes: 132 additions & 0 deletions test/function_wrapper_aliases.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
using DiffEqBase, ForwardDiff, Test
using DiffEqBase: Void, FunctionWrappersWrappers, OrdinaryDiffEqTag,
wrapfun_iip, wrapfun_iip_simple, AnyFunctionWrapper
using FunctionWrappers: FunctionWrapper

# Get the ForwardDiff extension module
const FDExt = Base.get_extension(DiffEqBase, :DiffEqBaseForwardDiffExt)

@testset "DEIIPFunctionWrapper struct (no ForwardDiff)" begin
ff = (du, u, p, t) -> (du .= u; nothing)
du = zeros(3)
u = ones(3)
p = [1.0, 2.0]
t = 0.0

# wrapfun_iip_simple should produce a DEIIPFunctionWrapper
wrapped = wrapfun_iip_simple(ff, du, u, p, t)
@test wrapped isa DiffEqBase.DEIIPFunctionWrapper
@test wrapped isa DiffEqBase.DEIIPFunctionWrapper{
Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64,
}

# VF64 alias should match
@test wrapped isa DiffEqBase.DEIIPFunctionWrapperVF64{Vector{Float64}}

# Should match AnyFunctionWrapper union
@test wrapped isa AnyFunctionWrapper

# The wrapped function should be callable and produce correct results
du_test = zeros(3)
wrapped(du_test, u, p, t)
@test du_test == u

# isfunctionwrapper should return true
@test SciMLBase.isfunctionwrapper(wrapped)

# Stack trace type string should be short
type_str = string(typeof(wrapped))
@test occursin("DEIIPFunctionWrapper", type_str)
@test !occursin("FunctionWrappersWrapper", type_str)
end

@testset "DEIIPFunctionWrapperVF64 with NullParameters" begin
ff = (du, u, p, t) -> (du .= u; nothing)
du = zeros(3)
u = ones(3)
p = SciMLBase.NullParameters()
t = 0.0

wrapped = wrapfun_iip_simple(ff, du, u, p, t)
@test wrapped isa DiffEqBase.DEIIPFunctionWrapper
@test wrapped isa DiffEqBase.DEIIPFunctionWrapperVF64{SciMLBase.NullParameters}
end

@testset "ODEDualTag and ODEDualType (ForwardDiff extension)" begin
@test FDExt.ODEDualTag === ForwardDiff.Tag{OrdinaryDiffEqTag, Float64}
@test FDExt.ODEDualType === ForwardDiff.Dual{
ForwardDiff.Tag{OrdinaryDiffEqTag, Float64}, Float64, 1,
}
end

@testset "DEIIPFunctionWrapperForwardDiff struct" begin
ff = (du, u, p, t) -> (du .= u; nothing)
du = zeros(3)
u = ones(3)
p = [1.0, 2.0]
t = 0.0

# wrapfun_iip with ForwardDiff loaded should produce a DEIIPFunctionWrapperForwardDiff
wrapped = wrapfun_iip(ff, (du, u, p, t))
@test wrapped isa DiffEqBase.DEIIPFunctionWrapperForwardDiff

# VF64 alias should match
@test wrapped isa FDExt.DEIIPFunctionWrapperForwardDiffVF64{Vector{Float64}}

# Should match AnyFunctionWrapper union
@test wrapped isa AnyFunctionWrapper

# The wrapped function should be callable
du_test = zeros(3)
wrapped(du_test, u, p, t)
@test du_test == u

# isfunctionwrapper should return true
@test SciMLBase.isfunctionwrapper(wrapped)

# Stack trace type string should NOT contain FunctionWrappersWrapper
type_str = string(typeof(wrapped))
@test occursin("DEIIPFunctionWrapperForwardDiff", type_str)
@test !occursin("FunctionWrappersWrapper", type_str)
end

@testset "DEIIPFunctionWrapperForwardDiffVF64 with NullParameters" begin
ff = (du, u, p, t) -> (du .= u; nothing)

# Default wrapfun_iip (no args) produces the 7-wrapper variant, not 4-wrapper
wrapped_default = wrapfun_iip(ff)
# The default 7-wrapper has a different structure (7 entries, not 4),
# so it should NOT be a DEIIPFunctionWrapperForwardDiff
@test !(wrapped_default isa DiffEqBase.DEIIPFunctionWrapperForwardDiff)
# But it IS still an AnyFunctionWrapper (raw FunctionWrappersWrapper)
@test wrapped_default isa AnyFunctionWrapper

# With explicit 4-tuple args and NullParameters, it should match
du = zeros(3)
u = ones(3)
p = SciMLBase.NullParameters()
t = 0.0
wrapped = wrapfun_iip(ff, (du, u, p, t))
@test wrapped isa DiffEqBase.DEIIPFunctionWrapperForwardDiff
@test wrapped isa FDExt.DEIIPFunctionWrapperForwardDiffVF64{SciMLBase.NullParameters}
end

@testset "wrapfun_iip_simple does not change behavior with ForwardDiff loaded" begin
ff = (du, u, p, t) -> (du .= u; nothing)
du = zeros(3)
u = ones(3)
p = [1.0, 2.0]
t = 0.0

# wrapfun_iip_simple should ALWAYS produce a DEIIPFunctionWrapper, even with ForwardDiff loaded
wrapped_simple = wrapfun_iip_simple(ff, du, u, p, t)
@test wrapped_simple isa DiffEqBase.DEIIPFunctionWrapper

# It should NOT match the ForwardDiff 4-wrapper struct
@test !(wrapped_simple isa DiffEqBase.DEIIPFunctionWrapperForwardDiff)

# wrapfun_iip should produce a DEIIPFunctionWrapperForwardDiff (ForwardDiff ext wraps it)
wrapped_fd = wrapfun_iip(ff, (du, u, p, t))
@test wrapped_fd isa DiffEqBase.DEIIPFunctionWrapperForwardDiff
@test !(wrapped_fd isa DiffEqBase.DEIIPFunctionWrapper)
end
Loading
Loading