Skip to content

Commit 44e2b75

Browse files
Add simpler type aliases and constructors for IIP FunctionWrappers
Add named type aliases for the FunctionWrappersWrapper types used in the VF64 pattern (DifferentialEquations.jl#1128): - `IIPFunctionWrapper{duType, uType, pType, tType}`: 1-wrapper alias for solvers without ForwardDiff (e.g. Tsit5, Verner) - `IIPFunctionWrapperVF64{pType}`: VF64-specialized version - `wrapfun_iip_simple(ff, du, u, p, t)`: constructor that avoids ForwardDiff extension backedges - `IIPFunctionWrapperForwardDiff{T1,T2,T3,T4,dT1,dT2,dT4}`: 4-wrapper alias for ForwardDiff-aware solvers (Rosenbrock, implicit methods) - `IIPFunctionWrapperForwardDiffVF64{pType}`: VF64-specialized version - `ODEDualTag`, `ODEDualType`: named aliases for the ODE ForwardDiff tag and dual number types These aliases enable the VF64 pattern in downstream packages (OrdinaryDiffEq, StochasticDiffEq, etc.) to reference these types by name for struct field annotations and pattern matching, instead of spelling out the full FunctionWrappersWrapper type parameters. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com> Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 6a58ddc commit 44e2b75

File tree

6 files changed

+239
-6
lines changed

6 files changed

+239
-6
lines changed

ext/DiffEqBaseForwardDiffExt.jl

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,100 @@ using DiffEqBase.ArrayInterface
55
using DiffEqBase: Void, FunctionWrappersWrappers, OrdinaryDiffEqTag,
66
AbstractTimeseriesSolution,
77
RecursiveArrayTools, reduce_tup, _promote_tspan, has_continuous_callback
8+
using FunctionWrappers: FunctionWrapper
89
import DiffEqBase: hasdualpromote, wrapfun_oop, wrapfun_iip, prob2dtmin,
910
promote_tspan, ODE_DEFAULT_NORM
1011
import SciMLBase: isdualtype, DualEltypeChecker, sse, __sum
1112

1213
const dualT = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag, Float64}, Float64, 1}
1314
dualgen(::Type{T}) where {T} = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag, T}, T, 1}
1415

16+
"""
17+
ODEDualTag
18+
19+
Type alias for the ForwardDiff tag used by ODE solvers:
20+
`ForwardDiff.Tag{OrdinaryDiffEqTag, Float64}`.
21+
22+
Access from downstream packages via:
23+
```julia
24+
ext = Base.get_extension(DiffEqBase, :DiffEqBaseForwardDiffExt)
25+
ext.ODEDualTag
26+
```
27+
"""
28+
const ODEDualTag = ForwardDiff.Tag{OrdinaryDiffEqTag, Float64}
29+
30+
"""
31+
ODEDualType
32+
33+
Type alias for the ForwardDiff dual number used by ODE solvers:
34+
`ForwardDiff.Dual{ODEDualTag, Float64, 1}`.
35+
36+
Access from downstream packages via:
37+
```julia
38+
ext = Base.get_extension(DiffEqBase, :DiffEqBaseForwardDiffExt)
39+
ext.ODEDualType
40+
```
41+
"""
42+
const ODEDualType = ForwardDiff.Dual{ODEDualTag, Float64, 1}
43+
44+
"""
45+
IIPFunctionWrapperForwardDiff{T1, T2, T3, T4, dT1, dT2, dT4}
46+
47+
Type alias for a `FunctionWrappersWrapper` containing 4 `FunctionWrapper` entries
48+
for an in-place function `f!(du, u, p, t) -> Nothing` with ForwardDiff support.
49+
50+
The 4 wrappers cover:
51+
1. Base types: `(T1, T2, T3, T4)`
52+
2. Dual state: `(dT1, dT2, T3, T4)`
53+
3. Dual time: `(dT1, T2, T3, dT4)`
54+
4. Dual state+time: `(dT1, dT2, T3, dT4)`
55+
56+
Used by solvers that require ForwardDiff internally (e.g. Rosenbrock, implicit methods).
57+
See also [`DiffEqBase.IIPFunctionWrapper`](@ref) for the simpler non-ForwardDiff variant.
58+
59+
Access from downstream packages via:
60+
```julia
61+
ext = Base.get_extension(DiffEqBase, :DiffEqBaseForwardDiffExt)
62+
ext.IIPFunctionWrapperForwardDiff
63+
```
64+
"""
65+
const IIPFunctionWrapperForwardDiff{T1, T2, T3, T4, dT1, dT2, dT4} =
66+
FunctionWrappersWrappers.FunctionWrappersWrapper{
67+
Tuple{
68+
FunctionWrapper{Nothing, Tuple{T1, T2, T3, T4}},
69+
FunctionWrapper{Nothing, Tuple{dT1, dT2, T3, T4}},
70+
FunctionWrapper{Nothing, Tuple{dT1, T2, T3, dT4}},
71+
FunctionWrapper{Nothing, Tuple{dT1, dT2, T3, dT4}},
72+
},
73+
false,
74+
}
75+
76+
"""
77+
IIPFunctionWrapperForwardDiffVF64{pType}
78+
79+
VF64-specialized alias for `IIPFunctionWrapperForwardDiff` matching the common
80+
in-place `Vector{Float64}` ODE case with ForwardDiff support.
81+
82+
Equivalent to:
83+
```julia
84+
IIPFunctionWrapperForwardDiff{
85+
Vector{Float64}, Vector{Float64}, pType, Float64,
86+
Vector{ODEDualType}, Vector{ODEDualType}, ODEDualType,
87+
}
88+
```
89+
90+
Access from downstream packages via:
91+
```julia
92+
ext = Base.get_extension(DiffEqBase, :DiffEqBaseForwardDiffExt)
93+
ext.IIPFunctionWrapperForwardDiffVF64
94+
```
95+
"""
96+
const IIPFunctionWrapperForwardDiffVF64{pType} =
97+
IIPFunctionWrapperForwardDiff{
98+
Vector{Float64}, Vector{Float64}, pType, Float64,
99+
Vector{dualT}, Vector{dualT}, dualT,
100+
}
101+
15102
const NORECOMPILE_IIP_SUPPORTED_ARGS = (
16103
Tuple{
17104
Vector{Float64}, Vector{Float64},

src/DiffEqBase.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,9 @@ export initialize!, finalize!
175175

176176
export SensitivityADPassThrough
177177

178+
# FunctionWrapper type aliases and constructors for the VF64 pattern
179+
export IIPFunctionWrapper, IIPFunctionWrapperVF64, wrapfun_iip_simple
180+
178181
include("precompilation.jl")
179182

180183
end # module

src/norecompile.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,47 @@ function unwrap_fw(fw::FunctionWrapper)
2121
return fw.obj[]
2222
end
2323

24+
"""
25+
IIPFunctionWrapper{duType, uType, pType, tType}
26+
27+
Type alias for a `FunctionWrappersWrapper` containing a single `FunctionWrapper`
28+
for an in-place function `f!(du, u, p, t) -> Nothing`.
29+
30+
Used by solvers that do **not** require ForwardDiff internally (e.g. Tsit5, Verner).
31+
See also [`IIPFunctionWrapperForwardDiff`](@ref) for the ForwardDiff-aware variant.
32+
"""
33+
const IIPFunctionWrapper{duType, uType, pType, tType} =
34+
FunctionWrappersWrappers.FunctionWrappersWrapper{
35+
Tuple{FunctionWrapper{Nothing, Tuple{duType, uType, pType, tType}}},
36+
false,
37+
}
38+
39+
"""
40+
IIPFunctionWrapperVF64{pType}
41+
42+
VF64-specialized alias: `IIPFunctionWrapper{Vector{Float64}, Vector{Float64}, pType, Float64}`.
43+
Matches the wrapper produced for the common in-place `Vector{Float64}` ODE case
44+
when ForwardDiff is **not** used by the solver.
45+
"""
46+
const IIPFunctionWrapperVF64{pType} =
47+
IIPFunctionWrapper{Vector{Float64}, Vector{Float64}, pType, Float64}
48+
49+
"""
50+
wrapfun_iip_simple(ff, du, u, p, t)
51+
52+
Wrap an in-place function `ff(du, u, p, t) -> Nothing` into an [`IIPFunctionWrapper`](@ref)
53+
(single `FunctionWrapper`, no ForwardDiff support).
54+
55+
Unlike [`wrapfun_iip`](@ref), this function is **not** overridden by the ForwardDiff
56+
extension, so it avoids creating method-table backedges that cause invalidation.
57+
Use this when the solver does not need ForwardDiff internally.
58+
"""
59+
function wrapfun_iip_simple(ff, du, u, p, t)
60+
return FunctionWrappersWrappers.FunctionWrappersWrapper(
61+
Void(ff), (typeof((du, u, p, t)),), (Nothing,)
62+
)
63+
end
64+
2465
# Default dispatch assumes no ForwardDiff, gets added in the new dispatch
2566
function wrapfun_iip(ff, inputs)
2667
return FunctionWrappersWrappers.FunctionWrappersWrapper(

src/solve.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -836,12 +836,7 @@ function promote_f(f::F, ::Val{specialize}, u0, p, t, ::Val{false}) where {F, sp
836836
!(f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper)
837837
)
838838
)
839-
return unwrapped_f(
840-
f,
841-
FunctionWrappersWrappers.FunctionWrappersWrapper(
842-
Void(f.f), (typeof((u0, u0, p, t)),), (Nothing,)
843-
)
844-
)
839+
return unwrapped_f(f, wrapfun_iip_simple(f.f, u0, u0, p, t))
845840
else
846841
return f
847842
end

test/function_wrapper_aliases.jl

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
using DiffEqBase, ForwardDiff, Test
2+
using DiffEqBase: Void, FunctionWrappersWrappers, OrdinaryDiffEqTag,
3+
wrapfun_iip, wrapfun_iip_simple
4+
using FunctionWrappers: FunctionWrapper
5+
6+
# Get the ForwardDiff extension module
7+
const FDExt = Base.get_extension(DiffEqBase, :DiffEqBaseForwardDiffExt)
8+
9+
@testset "IIPFunctionWrapper type alias (no ForwardDiff)" begin
10+
ff = (du, u, p, t) -> (du .= u; nothing)
11+
du = zeros(3)
12+
u = ones(3)
13+
p = [1.0, 2.0]
14+
t = 0.0
15+
16+
# wrapfun_iip_simple should produce an IIPFunctionWrapper
17+
wrapped = wrapfun_iip_simple(ff, du, u, p, t)
18+
@test wrapped isa DiffEqBase.IIPFunctionWrapper
19+
@test wrapped isa DiffEqBase.IIPFunctionWrapper{
20+
Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64}
21+
22+
# VF64 alias should match
23+
@test wrapped isa DiffEqBase.IIPFunctionWrapperVF64{Vector{Float64}}
24+
25+
# The wrapped function should be callable and produce correct results
26+
du_test = zeros(3)
27+
wrapped.fw[1](du_test, u, p, t)
28+
@test du_test == u
29+
end
30+
31+
@testset "IIPFunctionWrapperVF64 with NullParameters" begin
32+
ff = (du, u, p, t) -> (du .= u; nothing)
33+
du = zeros(3)
34+
u = ones(3)
35+
p = SciMLBase.NullParameters()
36+
t = 0.0
37+
38+
wrapped = wrapfun_iip_simple(ff, du, u, p, t)
39+
@test wrapped isa DiffEqBase.IIPFunctionWrapper
40+
@test wrapped isa DiffEqBase.IIPFunctionWrapperVF64{SciMLBase.NullParameters}
41+
end
42+
43+
@testset "ODEDualTag and ODEDualType (ForwardDiff extension)" begin
44+
@test FDExt.ODEDualTag === ForwardDiff.Tag{OrdinaryDiffEqTag, Float64}
45+
@test FDExt.ODEDualType === ForwardDiff.Dual{
46+
ForwardDiff.Tag{OrdinaryDiffEqTag, Float64}, Float64, 1}
47+
end
48+
49+
@testset "IIPFunctionWrapperForwardDiff type alias" begin
50+
ff = (du, u, p, t) -> (du .= u; nothing)
51+
du = zeros(3)
52+
u = ones(3)
53+
p = [1.0, 2.0]
54+
t = 0.0
55+
56+
# wrapfun_iip with ForwardDiff loaded should produce a 4-wrapper
57+
wrapped = wrapfun_iip(ff, (du, u, p, t))
58+
@test wrapped isa FDExt.IIPFunctionWrapperForwardDiff
59+
60+
# VF64 alias should match
61+
@test wrapped isa FDExt.IIPFunctionWrapperForwardDiffVF64{Vector{Float64}}
62+
63+
# The wrapped function should be callable
64+
du_test = zeros(3)
65+
wrapped.fw[1](du_test, u, p, t)
66+
@test du_test == u
67+
end
68+
69+
@testset "IIPFunctionWrapperForwardDiffVF64 with NullParameters" begin
70+
ff = (du, u, p, t) -> (du .= u; nothing)
71+
72+
# Default wrapfun_iip (no args) produces the 7-wrapper variant, not 4-wrapper
73+
wrapped_default = wrapfun_iip(ff)
74+
# The default 7-wrapper has a different structure (7 entries, not 4),
75+
# so it should NOT match the 4-wrapper alias
76+
@test !(wrapped_default isa FDExt.IIPFunctionWrapperForwardDiff)
77+
78+
# But with explicit 4-tuple args and NullParameters, it should match
79+
du = zeros(3)
80+
u = ones(3)
81+
p = SciMLBase.NullParameters()
82+
t = 0.0
83+
wrapped = wrapfun_iip(ff, (du, u, p, t))
84+
@test wrapped isa FDExt.IIPFunctionWrapperForwardDiff
85+
@test wrapped isa FDExt.IIPFunctionWrapperForwardDiffVF64{SciMLBase.NullParameters}
86+
end
87+
88+
@testset "wrapfun_iip_simple does not change behavior with ForwardDiff loaded" begin
89+
ff = (du, u, p, t) -> (du .= u; nothing)
90+
du = zeros(3)
91+
u = ones(3)
92+
p = [1.0, 2.0]
93+
t = 0.0
94+
95+
# wrapfun_iip_simple should ALWAYS produce a 1-wrapper, even with ForwardDiff loaded
96+
wrapped_simple = wrapfun_iip_simple(ff, du, u, p, t)
97+
@test wrapped_simple isa DiffEqBase.IIPFunctionWrapper
98+
99+
# It should NOT match the ForwardDiff 4-wrapper alias
100+
@test !(wrapped_simple isa FDExt.IIPFunctionWrapperForwardDiff)
101+
102+
# wrapfun_iip should produce a 4-wrapper (ForwardDiff ext overrides)
103+
wrapped_fd = wrapfun_iip(ff, (du, u, p, t))
104+
@test wrapped_fd isa FDExt.IIPFunctionWrapperForwardDiff
105+
@test !(wrapped_fd isa DiffEqBase.IIPFunctionWrapper)
106+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ end
3636
@time @safetestset "Norm" include("norm.jl")
3737
@time @safetestset "Utils" include("utils.jl")
3838
@time @safetestset "ForwardDiff Dual Detection" include("forwarddiff_dual_detection.jl")
39+
@time @safetestset "FunctionWrapper Aliases" include("function_wrapper_aliases.jl")
3940
@time @safetestset "ODE default norm" include("ode_default_norm.jl")
4041
@time @safetestset "ODE default unstable check" include("ode_default_unstable_check.jl")
4142
@time @safetestset "Problem Kwargs Merging" include("problem_kwargs_merging.jl")

0 commit comments

Comments
 (0)