Skip to content

Commit 18ff0f7

Browse files
Add wrapper structs for IIP FunctionWrappers to reduce stack trace length
Replace type aliases with actual wrapper structs that hide the verbose FunctionWrappersWrapper type parameters from stack traces. Before (1,077 chars for ForwardDiff wrapper): FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{...}}, ...}, false} After (~350 chars): DiffEqBase.DEIIPFunctionWrapperForwardDiff{Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Vector{Dual{...}}, Vector{Dual{...}}, Dual{...}} New types (exported from DiffEqBase): - DEIIPFunctionWrapper{duType, uType, pType, tType}: 1-wrapper struct for solvers without ForwardDiff (e.g. Tsit5, Verner) - DEIIPFunctionWrapperVF64{pType}: VF64-specialized alias - DEIIPFunctionWrapperForwardDiff{T1,T2,T3,T4,dT1,dT2,dT4}: 4-wrapper struct for ForwardDiff-aware solvers (Rosenbrock, implicit methods) - AnyFunctionWrapper: Union type for isa checks - wrapfun_iip_simple(ff, du, u, p, t): constructor that avoids ForwardDiff extension backedges ForwardDiff extension provides: - DEIIPFunctionWrapperForwardDiffVF64{pType}: VF64-specialized alias - ODEDualTag, ODEDualType: named aliases for ODE dual types Addresses DifferentialEquations.jl#1128. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com> Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 6a58ddc commit 18ff0f7

File tree

6 files changed

+293
-11
lines changed

6 files changed

+293
-11
lines changed

ext/DiffEqBaseForwardDiffExt.jl

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,71 @@ module DiffEqBaseForwardDiffExt
33
using DiffEqBase, ForwardDiff
44
using DiffEqBase.ArrayInterface
55
using DiffEqBase: Void, FunctionWrappersWrappers, OrdinaryDiffEqTag,
6+
DEIIPFunctionWrapperForwardDiff,
67
AbstractTimeseriesSolution,
78
RecursiveArrayTools, reduce_tup, _promote_tspan, has_continuous_callback
9+
using FunctionWrappers: FunctionWrapper
810
import DiffEqBase: hasdualpromote, wrapfun_oop, wrapfun_iip, prob2dtmin,
911
promote_tspan, ODE_DEFAULT_NORM
1012
import SciMLBase: isdualtype, DualEltypeChecker, sse, __sum
1113

1214
const dualT = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag, Float64}, Float64, 1}
1315
dualgen(::Type{T}) where {T} = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag, T}, T, 1}
1416

17+
"""
18+
ODEDualTag
19+
20+
Type alias for the ForwardDiff tag used by ODE solvers:
21+
`ForwardDiff.Tag{OrdinaryDiffEqTag, Float64}`.
22+
23+
Access from downstream packages via:
24+
```julia
25+
ext = Base.get_extension(DiffEqBase, :DiffEqBaseForwardDiffExt)
26+
ext.ODEDualTag
27+
```
28+
"""
29+
const ODEDualTag = ForwardDiff.Tag{OrdinaryDiffEqTag, Float64}
30+
31+
"""
32+
ODEDualType
33+
34+
Type alias for the ForwardDiff dual number used by ODE solvers:
35+
`ForwardDiff.Dual{ODEDualTag, Float64, 1}`.
36+
37+
Access from downstream packages via:
38+
```julia
39+
ext = Base.get_extension(DiffEqBase, :DiffEqBaseForwardDiffExt)
40+
ext.ODEDualType
41+
```
42+
"""
43+
const ODEDualType = ForwardDiff.Dual{ODEDualTag, Float64, 1}
44+
45+
"""
46+
DEIIPFunctionWrapperForwardDiffVF64{pType}
47+
48+
VF64-specialized alias for `DEIIPFunctionWrapperForwardDiff` matching the common
49+
in-place `Vector{Float64}` ODE case with ForwardDiff support.
50+
51+
Equivalent to:
52+
```julia
53+
DEIIPFunctionWrapperForwardDiff{
54+
Vector{Float64}, Vector{Float64}, pType, Float64,
55+
Vector{ODEDualType}, Vector{ODEDualType}, ODEDualType,
56+
}
57+
```
58+
59+
Access from downstream packages via:
60+
```julia
61+
ext = Base.get_extension(DiffEqBase, :DiffEqBaseForwardDiffExt)
62+
ext.DEIIPFunctionWrapperForwardDiffVF64
63+
```
64+
"""
65+
const DEIIPFunctionWrapperForwardDiffVF64{pType} =
66+
DEIIPFunctionWrapperForwardDiff{
67+
Vector{Float64}, Vector{Float64}, pType, Float64,
68+
Vector{dualT}, Vector{dualT}, dualT,
69+
}
70+
1571
const NORECOMPILE_IIP_SUPPORTED_ARGS = (
1672
Tuple{
1773
Vector{Float64}, Vector{Float64},
@@ -82,7 +138,8 @@ function wrapfun_iip(
82138
fwt = map(iip_arglists, iip_returnlists) do A, R
83139
FunctionWrappersWrappers.FunctionWrappers.FunctionWrapper{R, A}(Void(ff))
84140
end
85-
return FunctionWrappersWrappers.FunctionWrappersWrapper{typeof(fwt), false}(fwt)
141+
inner = FunctionWrappersWrappers.FunctionWrappersWrapper{typeof(fwt), false}(fwt)
142+
return DEIIPFunctionWrapperForwardDiff(inner)
86143
end
87144

88145
const iip_arglists_default = (

src/DiffEqBase.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,9 @@ include("common_defaults.jl")
149149
include("solve.jl")
150150
include("internal_euler.jl")
151151
include("norecompile.jl")
152+
# unwrapped_f support for DE wrapper structs (delegates to inner FunctionWrappersWrapper)
153+
unwrapped_f(f::DEIIPFunctionWrapper) = unwrapped_f(f.fw)
154+
unwrapped_f(f::DEIIPFunctionWrapperForwardDiff) = unwrapped_f(f.fw)
152155
include("integrator_accessors.jl")
153156

154157
# This is only used for oop stiff solvers
@@ -175,6 +178,10 @@ export initialize!, finalize!
175178

176179
export SensitivityADPassThrough
177180

181+
# FunctionWrapper structs and aliases for the VF64 pattern
182+
export DEIIPFunctionWrapper, DEIIPFunctionWrapperVF64,
183+
DEIIPFunctionWrapperForwardDiff, AnyFunctionWrapper, wrapfun_iip_simple
184+
178185
include("precompilation.jl")
179186

180187
end # module

src/norecompile.jl

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,98 @@ function unwrap_fw(fw::FunctionWrapper)
2121
return fw.obj[]
2222
end
2323

24+
"""
25+
DEIIPFunctionWrapper{duType, uType, pType, tType}
26+
27+
Wrapper struct around a `FunctionWrappersWrapper` containing a single `FunctionWrapper`
28+
for an in-place function `f!(du, u, p, t) -> Nothing`.
29+
30+
Compared to a raw `FunctionWrappersWrapper`, this struct exposes only 4 type
31+
parameters (du, u, p, t types) instead of the full nested
32+
`FunctionWrappersWrapper{Tuple{FunctionWrapper{...}}, false}` type, which
33+
significantly reduces type string length in stack traces.
34+
35+
Used by solvers that do **not** require ForwardDiff internally (e.g. Tsit5, Verner).
36+
See also [`DEIIPFunctionWrapperForwardDiff`](@ref) for the ForwardDiff-aware variant.
37+
"""
38+
struct DEIIPFunctionWrapper{duType, uType, pType, tType}
39+
fw::FunctionWrappersWrappers.FunctionWrappersWrapper{
40+
Tuple{FunctionWrapper{Nothing, Tuple{duType, uType, pType, tType}}},
41+
false,
42+
}
43+
end
44+
45+
(f::DEIIPFunctionWrapper)(args...) = f.fw(args...)
46+
SciMLBase.isfunctionwrapper(::DEIIPFunctionWrapper) = true
47+
48+
"""
49+
DEIIPFunctionWrapperVF64{pType}
50+
51+
VF64-specialized alias: `DEIIPFunctionWrapper{Vector{Float64}, Vector{Float64}, pType, Float64}`.
52+
Matches the wrapper produced for the common in-place `Vector{Float64}` ODE case
53+
when ForwardDiff is **not** used by the solver.
54+
"""
55+
const DEIIPFunctionWrapperVF64{pType} =
56+
DEIIPFunctionWrapper{Vector{Float64}, Vector{Float64}, pType, Float64}
57+
58+
"""
59+
DEIIPFunctionWrapperForwardDiff{T1, T2, T3, T4, dT1, dT2, dT4}
60+
61+
Wrapper struct around a `FunctionWrappersWrapper` containing 4 `FunctionWrapper`
62+
entries for an in-place function `f!(du, u, p, t) -> Nothing` with ForwardDiff support.
63+
64+
The 4 wrappers cover:
65+
1. Base types: `(T1, T2, T3, T4)`
66+
2. Dual state: `(dT1, dT2, T3, T4)`
67+
3. Dual time: `(dT1, T2, T3, dT4)`
68+
4. Dual state+time: `(dT1, dT2, T3, dT4)`
69+
70+
Compared to a raw `FunctionWrappersWrapper`, this struct exposes 7 type parameters
71+
instead of repeating the full `FunctionWrapper{Nothing, Tuple{...}}` 4 times,
72+
significantly reducing type string length in stack traces.
73+
74+
Used by solvers that require ForwardDiff internally (e.g. Rosenbrock, implicit methods).
75+
See also [`DEIIPFunctionWrapper`](@ref) for the simpler non-ForwardDiff variant.
76+
"""
77+
struct DEIIPFunctionWrapperForwardDiff{T1, T2, T3, T4, dT1, dT2, dT4}
78+
fw::FunctionWrappersWrappers.FunctionWrappersWrapper{
79+
Tuple{
80+
FunctionWrapper{Nothing, Tuple{T1, T2, T3, T4}},
81+
FunctionWrapper{Nothing, Tuple{dT1, dT2, T3, T4}},
82+
FunctionWrapper{Nothing, Tuple{dT1, T2, T3, dT4}},
83+
FunctionWrapper{Nothing, Tuple{dT1, dT2, T3, dT4}},
84+
},
85+
false,
86+
}
87+
end
88+
89+
(f::DEIIPFunctionWrapperForwardDiff)(args...) = f.fw(args...)
90+
SciMLBase.isfunctionwrapper(::DEIIPFunctionWrapperForwardDiff) = true
91+
92+
# Union for isa checks (avoid double-wrapping)
93+
const AnyFunctionWrapper = Union{
94+
FunctionWrappersWrappers.FunctionWrappersWrapper,
95+
DEIIPFunctionWrapper,
96+
DEIIPFunctionWrapperForwardDiff,
97+
}
98+
99+
"""
100+
wrapfun_iip_simple(ff, du, u, p, t)
101+
102+
Wrap an in-place function `ff(du, u, p, t) -> Nothing` into a [`DEIIPFunctionWrapper`](@ref)
103+
(single `FunctionWrapper`, no ForwardDiff support).
104+
105+
Unlike [`wrapfun_iip`](@ref), this function is **not** overridden by the ForwardDiff
106+
extension, so it avoids creating method-table backedges that cause invalidation.
107+
Use this when the solver does not need ForwardDiff internally.
108+
"""
109+
function wrapfun_iip_simple(ff, du, u, p, t)
110+
inner = FunctionWrappersWrappers.FunctionWrappersWrapper(
111+
Void(ff), (typeof((du, u, p, t)),), (Nothing,)
112+
)
113+
return DEIIPFunctionWrapper(inner)
114+
end
115+
24116
# Default dispatch assumes no ForwardDiff, gets added in the new dispatch
25117
function wrapfun_iip(ff, inputs)
26118
return FunctionWrappersWrappers.FunctionWrappersWrapper(

src/solve.jl

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -790,16 +790,16 @@ function promote_f(f::F, ::Val{specialize}, u0, p, t, ::Val{true}) where {F, spe
790790
) ||
791791
(
792792
specialize === SciMLBase.FunctionWrapperSpecialize &&
793-
!(f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper)
793+
!(f.f isa AnyFunctionWrapper)
794794
)
795795
)
796796
# Wrap tgrad if present, so its type is also erased.
797797
# tgrad!(dT, u, p, t) -> Nothing has the same shape as the RHS.
798-
if f.tgrad !== nothing && !(f.tgrad isa FunctionWrappersWrappers.FunctionWrappersWrapper)
798+
if f.tgrad !== nothing && !(f.tgrad isa AnyFunctionWrapper)
799799
f = @set f.tgrad = wrapfun_jac_iip(f.tgrad, (u0, u0, p, t))
800800
end
801801
# Wrap the Jacobian if present, so its type is also erased
802-
if f.jac !== nothing && !(f.jac isa FunctionWrappersWrappers.FunctionWrappersWrapper)
802+
if f.jac !== nothing && !(f.jac isa AnyFunctionWrapper)
803803
n = length(u0)
804804
J_proto = f.jac_prototype !== nothing ? similar(f.jac_prototype, uElType) :
805805
zeros(uElType, n, n)
@@ -833,15 +833,10 @@ function promote_f(f::F, ::Val{specialize}, u0, p, t, ::Val{false}) where {F, sp
833833
) ||
834834
(
835835
specialize === SciMLBase.FunctionWrapperSpecialize &&
836-
!(f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper)
837-
)
838-
)
839-
return unwrapped_f(
840-
f,
841-
FunctionWrappersWrappers.FunctionWrappersWrapper(
842-
Void(f.f), (typeof((u0, u0, p, t)),), (Nothing,)
836+
!(f.f isa AnyFunctionWrapper)
843837
)
844838
)
839+
return unwrapped_f(f, wrapfun_iip_simple(f.f, u0, u0, p, t))
845840
else
846841
return f
847842
end

test/function_wrapper_aliases.jl

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
using DiffEqBase, ForwardDiff, Test
2+
using DiffEqBase: Void, FunctionWrappersWrappers, OrdinaryDiffEqTag,
3+
wrapfun_iip, wrapfun_iip_simple, AnyFunctionWrapper
4+
using FunctionWrappers: FunctionWrapper
5+
6+
# Get the ForwardDiff extension module
7+
const FDExt = Base.get_extension(DiffEqBase, :DiffEqBaseForwardDiffExt)
8+
9+
@testset "DEIIPFunctionWrapper struct (no ForwardDiff)" begin
10+
ff = (du, u, p, t) -> (du .= u; nothing)
11+
du = zeros(3)
12+
u = ones(3)
13+
p = [1.0, 2.0]
14+
t = 0.0
15+
16+
# wrapfun_iip_simple should produce a DEIIPFunctionWrapper
17+
wrapped = wrapfun_iip_simple(ff, du, u, p, t)
18+
@test wrapped isa DiffEqBase.DEIIPFunctionWrapper
19+
@test wrapped isa DiffEqBase.DEIIPFunctionWrapper{
20+
Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64}
21+
22+
# VF64 alias should match
23+
@test wrapped isa DiffEqBase.DEIIPFunctionWrapperVF64{Vector{Float64}}
24+
25+
# Should match AnyFunctionWrapper union
26+
@test wrapped isa AnyFunctionWrapper
27+
28+
# The wrapped function should be callable and produce correct results
29+
du_test = zeros(3)
30+
wrapped(du_test, u, p, t)
31+
@test du_test == u
32+
33+
# isfunctionwrapper should return true
34+
@test SciMLBase.isfunctionwrapper(wrapped)
35+
36+
# Stack trace type string should be short
37+
type_str = string(typeof(wrapped))
38+
@test occursin("DEIIPFunctionWrapper", type_str)
39+
@test !occursin("FunctionWrappersWrapper", type_str)
40+
end
41+
42+
@testset "DEIIPFunctionWrapperVF64 with NullParameters" begin
43+
ff = (du, u, p, t) -> (du .= u; nothing)
44+
du = zeros(3)
45+
u = ones(3)
46+
p = SciMLBase.NullParameters()
47+
t = 0.0
48+
49+
wrapped = wrapfun_iip_simple(ff, du, u, p, t)
50+
@test wrapped isa DiffEqBase.DEIIPFunctionWrapper
51+
@test wrapped isa DiffEqBase.DEIIPFunctionWrapperVF64{SciMLBase.NullParameters}
52+
end
53+
54+
@testset "ODEDualTag and ODEDualType (ForwardDiff extension)" begin
55+
@test FDExt.ODEDualTag === ForwardDiff.Tag{OrdinaryDiffEqTag, Float64}
56+
@test FDExt.ODEDualType === ForwardDiff.Dual{
57+
ForwardDiff.Tag{OrdinaryDiffEqTag, Float64}, Float64, 1}
58+
end
59+
60+
@testset "DEIIPFunctionWrapperForwardDiff struct" begin
61+
ff = (du, u, p, t) -> (du .= u; nothing)
62+
du = zeros(3)
63+
u = ones(3)
64+
p = [1.0, 2.0]
65+
t = 0.0
66+
67+
# wrapfun_iip with ForwardDiff loaded should produce a DEIIPFunctionWrapperForwardDiff
68+
wrapped = wrapfun_iip(ff, (du, u, p, t))
69+
@test wrapped isa DiffEqBase.DEIIPFunctionWrapperForwardDiff
70+
71+
# VF64 alias should match
72+
@test wrapped isa FDExt.DEIIPFunctionWrapperForwardDiffVF64{Vector{Float64}}
73+
74+
# Should match AnyFunctionWrapper union
75+
@test wrapped isa AnyFunctionWrapper
76+
77+
# The wrapped function should be callable
78+
du_test = zeros(3)
79+
wrapped(du_test, u, p, t)
80+
@test du_test == u
81+
82+
# isfunctionwrapper should return true
83+
@test SciMLBase.isfunctionwrapper(wrapped)
84+
85+
# Stack trace type string should NOT contain FunctionWrappersWrapper
86+
type_str = string(typeof(wrapped))
87+
@test occursin("DEIIPFunctionWrapperForwardDiff", type_str)
88+
@test !occursin("FunctionWrappersWrapper", type_str)
89+
end
90+
91+
@testset "DEIIPFunctionWrapperForwardDiffVF64 with NullParameters" begin
92+
ff = (du, u, p, t) -> (du .= u; nothing)
93+
94+
# Default wrapfun_iip (no args) produces the 7-wrapper variant, not 4-wrapper
95+
wrapped_default = wrapfun_iip(ff)
96+
# The default 7-wrapper has a different structure (7 entries, not 4),
97+
# so it should NOT be a DEIIPFunctionWrapperForwardDiff
98+
@test !(wrapped_default isa DiffEqBase.DEIIPFunctionWrapperForwardDiff)
99+
# But it IS still an AnyFunctionWrapper (raw FunctionWrappersWrapper)
100+
@test wrapped_default isa AnyFunctionWrapper
101+
102+
# With explicit 4-tuple args and NullParameters, it should match
103+
du = zeros(3)
104+
u = ones(3)
105+
p = SciMLBase.NullParameters()
106+
t = 0.0
107+
wrapped = wrapfun_iip(ff, (du, u, p, t))
108+
@test wrapped isa DiffEqBase.DEIIPFunctionWrapperForwardDiff
109+
@test wrapped isa FDExt.DEIIPFunctionWrapperForwardDiffVF64{SciMLBase.NullParameters}
110+
end
111+
112+
@testset "wrapfun_iip_simple does not change behavior with ForwardDiff loaded" begin
113+
ff = (du, u, p, t) -> (du .= u; nothing)
114+
du = zeros(3)
115+
u = ones(3)
116+
p = [1.0, 2.0]
117+
t = 0.0
118+
119+
# wrapfun_iip_simple should ALWAYS produce a DEIIPFunctionWrapper, even with ForwardDiff loaded
120+
wrapped_simple = wrapfun_iip_simple(ff, du, u, p, t)
121+
@test wrapped_simple isa DiffEqBase.DEIIPFunctionWrapper
122+
123+
# It should NOT match the ForwardDiff 4-wrapper struct
124+
@test !(wrapped_simple isa DiffEqBase.DEIIPFunctionWrapperForwardDiff)
125+
126+
# wrapfun_iip should produce a DEIIPFunctionWrapperForwardDiff (ForwardDiff ext wraps it)
127+
wrapped_fd = wrapfun_iip(ff, (du, u, p, t))
128+
@test wrapped_fd isa DiffEqBase.DEIIPFunctionWrapperForwardDiff
129+
@test !(wrapped_fd isa DiffEqBase.DEIIPFunctionWrapper)
130+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ end
3636
@time @safetestset "Norm" include("norm.jl")
3737
@time @safetestset "Utils" include("utils.jl")
3838
@time @safetestset "ForwardDiff Dual Detection" include("forwarddiff_dual_detection.jl")
39+
@time @safetestset "FunctionWrapper Aliases" include("function_wrapper_aliases.jl")
3940
@time @safetestset "ODE default norm" include("ode_default_norm.jl")
4041
@time @safetestset "ODE default unstable check" include("ode_default_unstable_check.jl")
4142
@time @safetestset "Problem Kwargs Merging" include("problem_kwargs_merging.jl")

0 commit comments

Comments
 (0)