Skip to content

Commit c9b06fb

Browse files
WIP: Make ForwardDiff Optional
```julia @time using OrdinaryDiffEqTsit5 # 0.678347 seconds (806.13 k allocations: 60.106 MiB, 3.97% gc time, 7.38% compilation time) @time begin using OrdinaryDiffEqTsit5 function lorenz(du,u,p,t) du[1] = 10.0(u[2]-u[1]) du[2] = u[1]*(28.0-u[3]) - u[2] du[3] = u[1]*u[2] - (8/3)*u[3] end u0 = [1.0;0.0;0.0]; tspan = (0.0,100.0) prob = ODEProblem(lorenz,u0,tspan) solve(prob,Tsit5()) end # 0.765026 seconds (1.29 M allocations: 86.179 MiB, 2.16% gc time, 24.20% compilation time: 2% of which was recompilation) @time_imports using OrdinaryDiffEqTsit5 #= ┌ 0.0 ms DocStringExtensions.__init__() 1.4 ms DocStringExtensions 0.5 ms Reexport 7.6 ms Preferences 0.6 ms PrecompileTools 0.5 ms FastPower 1.2 ms ArrayInterface 0.9 ms StaticArraysCore 0.7 ms ArrayInterface → ArrayInterfaceStaticArraysCoreExt 5.6 ms FunctionWrappers 0.5 ms MuladdMacro 3.2 ms OrderedCollections 0.5 ms UnPack 0.7 ms Parameters 0.8 ms Statistics 0.4 ms IfElse 0.9 ms CommonWorldInvalidations 9.8 ms Static 0.8 ms Compat 0.5 ms Compat → CompatLinearAlgebraExt 12.2 ms StaticArrayInterface 1.0 ms ManualMemory ┌ 0.0 ms ThreadingUtilities.__init__() 4.7 ms ThreadingUtilities 0.7 ms SIMDTypes 2.0 ms LayoutPointers 2.9 ms CloseOpenIntervals 8.5 ms StrideArraysCore 0.6 ms BitTwiddlingConvenienceFunctions ┌ 0.0 ms CPUSummary.__init__() 1.6 ms CPUSummary ┌ 0.0 ms PolyesterWeave.__init__() 2.6 ms PolyesterWeave 0.8 ms Polyester 1.5 ms FastBroadcast 9.4 ms RecipesBase 0.8 ms ExprTools 1.1 ms Serialization 0.9 ms RuntimeGeneratedFunctions 7.2 ms MacroTools ┌ 0.0 ms InverseFunctions.__init__() 1.6 ms InverseFunctions 0.7 ms ConstructionBase 0.6 ms ConstructionBase → ConstructionBaseLinearAlgebraExt 0.6 ms CompositionsBase 0.6 ms CompositionsBase → CompositionsBaseInverseFunctionsExt 0.6 ms InverseFunctions → InverseFunctionsDatesExt ┌ 0.0 ms Accessors.__init__() 10.6 ms Accessors 0.7 ms Accessors → LinearAlgebraExt 2.9 ms SymbolicIndexingInterface 0.8 ms Adapt 0.7 ms DataValueInterfaces 1.0 ms DataAPI 0.7 ms IteratorInterfaceExtensions 0.7 ms TableTraits 8.2 ms Tables 0.9 ms GPUArraysCore 0.6 ms ArrayInterface → ArrayInterfaceGPUArraysCoreExt 17.1 ms RecursiveArrayTools 0.8 ms RecursiveArrayTools → RecursiveArrayToolsFastBroadcastExt ┌ 0.0 ms TruncatedStacktraces.__init__() 0.8 ms TruncatedStacktraces 6.6 ms Setfield 14.8 ms IrrationalConstants 1.2 ms DiffRules 1.3 ms DiffResults ┌ 2.1 ms OpenLibm_jll.__init__() 4.9 ms OpenLibm_jll 0.9 ms NaNMath 0.9 ms LogExpFunctions 0.8 ms JLLWrappers ┌ 12.0 ms CompilerSupportLibraries_jll.__init__() 15.4 ms CompilerSupportLibraries_jll ┌ 1.3 ms OpenSpecFun_jll.__init__() 2.3 ms OpenSpecFun_jll 5.7 ms SpecialFunctions 0.8 ms CommonSubexpressions 25.3 ms ForwardDiff 0.8 ms LogExpFunctions → LogExpFunctionsInverseFunctionsExt 0.7 ms FastPower → FastPowerForwardDiffExt 0.7 ms RecursiveArrayTools → RecursiveArrayToolsForwardDiffExt 0.9 ms EnumX 0.8 ms ConcreteStructs 0.9 ms FastClosures 1.1 ms PreallocationTools 0.8 ms FunctionWrappersWrappers 0.9 ms SciMLStructures ┌ 0.0 ms Distributed.__init__() 9.3 ms Distributed 0.8 ms CommonSolve 3.1 ms ADTypes 1.2 ms ADTypes → ADTypesConstructionBaseExt 75.6 ms MLStyle 20.54% compilation time 4.6 ms Expronicon 8.7 ms SciMLOperators 1.0 ms SciMLOperators → SciMLOperatorsStaticArraysCoreExt ┌ 0.0 ms SciMLBase.__init__() 76.9 ms SciMLBase 5.1 ms DiffEqBase 38.0 ms FillArrays 1.3 ms FillArrays → FillArraysStatisticsExt 0.9 ms SimpleUnPack 21.6 ms DataStructures 31.1 ms OrdinaryDiffEqCore 39.0 ms OrdinaryDiffEqTsit5 =# ``` This shows that ForwardDiff and its dependencies are a large part of the startup time, and that comes through DiffEqBase. With the DifferentiationInterface changes we will hopefully no longer default to ForwardDiff anyways and instead with Enzyme, and so we should make the stack not require ForwardDiff unless the user/solver wants it.
1 parent 6e6e63b commit c9b06fb

File tree

7 files changed

+187
-333
lines changed

7 files changed

+187
-333
lines changed

Project.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,13 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
1313
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
1414
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
1515
FastPower = "a4df4552-cc26-4903-aec0-212e50a0e84b"
16-
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1716
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
1817
FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf"
1918
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2019
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
2120
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
2221
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
2322
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
24-
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
2523
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
2624
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
2725
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
@@ -40,6 +38,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
4038
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
4139
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
4240
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
41+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
4342
GeneralizedGenerated = "6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb"
4443
GTPSA = "b27dd330-f138-47c5-815b-40db9dd9b6e8"
4544
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
@@ -55,6 +54,7 @@ DiffEqBaseCUDAExt = "CUDA"
5554
DiffEqBaseChainRulesCoreExt = "ChainRulesCore"
5655
DiffEqBaseDistributionsExt = "Distributions"
5756
DiffEqBaseEnzymeExt = ["ChainRulesCore", "Enzyme"]
57+
DiffEqBaseForwardDiffExt = ["ForwardDiff"]
5858
DiffEqBaseGeneralizedGeneratedExt = "GeneralizedGenerated"
5959
DiffEqBaseGTPSAExt = "GTPSA"
6060
DiffEqBaseMPIExt = "MPI"
@@ -92,7 +92,6 @@ Measurements = "2"
9292
MonteCarloMeasurements = "1"
9393
MuladdMacro = "0.2.1"
9494
Parameters = "0.12.0"
95-
PreallocationTools = "0.4"
9695
PrecompileTools = "1"
9796
Printf = "1.9"
9897
RecursiveArrayTools = "3"

src/forwarddiff.jl renamed to ext/DiffEqBaseForwardDiffExt.jl

Lines changed: 158 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,99 @@
1+
module DiffEqBaseForwardDiffExt
2+
3+
using DiffEqBase, ForwardDiff
4+
using DiffeqBase: Void, FunctionWrappersWrappers, OrdinaryDiffEqTag
5+
import DiffEqBase: hasdualpromote, wrapfun_oop, wrapfun_iip, promote_u0
6+
17
const DUALCHECK_RECURSION_MAX = 10
28

9+
eltypedual(x) = eltype(x) <: ForwardDiff.Dual
10+
const dualT = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag, Float64}, Float64, 1}
11+
dualgen(::Type{T}) where {T} = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag, T}, T, 1}
12+
13+
hasdualpromote(u0,t::Number) = hasmethod(ArrayInterface.promote_eltype,
14+
Tuple{Type{typeof(u0)}, Type{dualgen(eltype(u0))}}) &&
15+
hasmethod(promote_rule,
16+
Tuple{Type{eltype(u0)}, Type{dualgen(eltype(u0))}}) &&
17+
hasmethod(promote_rule,
18+
Tuple{Type{eltype(u0)}, Type{typeof(t)}}))
19+
20+
const NORECOMPILE_IIP_SUPPORTED_ARGS = (
21+
Tuple{Vector{Float64}, Vector{Float64},
22+
Vector{Float64}, Float64},
23+
Tuple{Vector{Float64}, Vector{Float64},
24+
SciMLBase.NullParameters, Float64})
25+
26+
const oop_arglists = (Tuple{Vector{Float64}, Vector{Float64}, Float64},
27+
Tuple{Vector{Float64}, SciMLBase.NullParameters, Float64},
28+
Tuple{Vector{Float64}, Vector{Float64}, dualT},
29+
Tuple{Vector{dualT}, Vector{Float64}, Float64},
30+
Tuple{Vector{dualT}, SciMLBase.NullParameters, Float64},
31+
Tuple{Vector{Float64}, SciMLBase.NullParameters, dualT})
32+
33+
const NORECOMPILE_OOP_SUPPORTED_ARGS = (Tuple{Vector{Float64},
34+
Vector{Float64}, Float64},
35+
Tuple{Vector{Float64},
36+
SciMLBase.NullParameters, Float64})
37+
const oop_returnlists = (Vector{Float64}, Vector{Float64},
38+
ntuple(x -> Vector{dualT}, length(oop_arglists) - 2)...)
39+
40+
function wrapfun_oop(ff, inputs::Tuple = ())
41+
if !isempty(inputs)
42+
IT = Tuple{map(typeof, inputs)...}
43+
if IT NORECOMPILE_OOP_SUPPORTED_ARGS
44+
throw(NoRecompileArgumentError(IT))
45+
end
46+
end
47+
FunctionWrappersWrappers.FunctionWrappersWrapper(ff, oop_arglists,
48+
oop_returnlists)
49+
end
50+
51+
function wrapfun_iip(ff,
52+
inputs::Tuple{T1, T2, T3, T4}) where {T1, T2, T3, T4}
53+
T = eltype(T2)
54+
dualT = dualgen(T)
55+
dualT1 = ArrayInterface.promote_eltype(T1, dualT)
56+
dualT2 = ArrayInterface.promote_eltype(T2, dualT)
57+
dualT4 = dualgen(promote_type(T, T4))
58+
59+
iip_arglists = (Tuple{T1, T2, T3, T4},
60+
Tuple{dualT1, dualT2, T3, T4},
61+
Tuple{dualT1, T2, T3, dualT4},
62+
Tuple{dualT1, dualT2, T3, dualT4})
63+
64+
iip_returnlists = ntuple(x -> Nothing, 4)
65+
66+
fwt = map(iip_arglists, iip_returnlists) do A, R
67+
FunctionWrappersWrappers.FunctionWrappers.FunctionWrapper{R, A}(Void(ff))
68+
end
69+
FunctionWrappersWrappers.FunctionWrappersWrapper{typeof(fwt), false}(fwt)
70+
end
71+
72+
const iip_arglists_default = (
73+
Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64},
74+
Float64},
75+
Tuple{Vector{Float64}, Vector{Float64},
76+
SciMLBase.NullParameters,
77+
Float64
78+
},
79+
Tuple{Vector{dualT}, Vector{Float64}, Vector{Float64}, dualT},
80+
Tuple{Vector{dualT}, Vector{dualT}, Vector{Float64}, dualT},
81+
Tuple{Vector{dualT}, Vector{dualT}, Vector{Float64}, Float64},
82+
Tuple{Vector{dualT}, Vector{dualT}, SciMLBase.NullParameters,
83+
Float64
84+
},
85+
Tuple{Vector{dualT}, Vector{Float64},
86+
SciMLBase.NullParameters, dualT
87+
})
88+
const iip_returnlists_default = ntuple(x -> Nothing, length(iip_arglists_default))
89+
90+
function wrapfun_iip(@nospecialize(ff))
91+
fwt = map(iip_arglists_default, iip_returnlists_default) do A, R
92+
FunctionWrappersWrappers.FunctionWrappers.FunctionWrapper{R, A}(Void(ff))
93+
end
94+
FunctionWrappersWrappers.FunctionWrappersWrapper{typeof(fwt), false}(fwt)
95+
end
96+
397
"""
498
promote_dual(::Type{T},::Type{T2})
599
@@ -371,8 +465,6 @@ end
371465
anyeltypedual(::@Kwargs{}, ::Type{Val{counter}} = Val{0}) where {counter} = Any
372466
anyeltypedual(::Type{@Kwargs{}}, ::Type{Val{counter}} = Val{0}) where {counter} = Any
373467

374-
@inline promote_u0(::Nothing, p, t0) = nothing
375-
376468
@inline function promote_u0(u0, p, t0)
377469
if SciMLStructures.isscimlstructure(p)
378470
_p = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1]
@@ -482,6 +574,69 @@ end
482574

483575
# Static Arrays don't support the `init` keyword argument for `sum`
484576
@inline __sum(f::F, args...; init, kwargs...) where {F} = sum(f, args...; init, kwargs...)
485-
@inline function __sum(f::F, a::StaticArraysCore.StaticArray...; init, kwargs...) where {F}
577+
@inline function __sum(f::F, a::DiffEqBase.StaticArraysCore.StaticArray...; init, kwargs...) where {F}
486578
return mapreduce(f, +, a...; init, kwargs...)
487579
end
580+
581+
# Differentiation of internal solver
582+
583+
function scalar_nlsolve_ad(prob, alg::InternalITP, args...; kwargs...)
584+
f = prob.f
585+
p = value(prob.p)
586+
587+
if prob isa IntervalNonlinearProblem
588+
tspan = value(prob.tspan)
589+
newprob = IntervalNonlinearProblem(f, tspan, p; prob.kwargs...)
590+
else
591+
u0 = value(prob.u0)
592+
newprob = NonlinearProblem(f, u0, p; prob.kwargs...)
593+
end
594+
595+
sol = solve(newprob, alg, args...; kwargs...)
596+
597+
uu = sol.u
598+
if p isa Number
599+
f_p = ForwardDiff.derivative(Base.Fix1(f, uu), p)
600+
else
601+
f_p = ForwardDiff.gradient(Base.Fix1(f, uu), p)
602+
end
603+
604+
f_x = ForwardDiff.derivative(Base.Fix2(f, p), uu)
605+
pp = prob.p
606+
sumfun = let f_x′ = -f_x
607+
((fp, p),) -> (fp / f_x′) * ForwardDiff.partials(p)
608+
end
609+
partials = sum(sumfun, zip(f_p, pp))
610+
return sol, partials
611+
end
612+
613+
function SciMLBase.solve(
614+
prob::IntervalNonlinearProblem{uType, iip,
615+
<:ForwardDiff.Dual{T, V, P}},
616+
alg::InternalITP, args...;
617+
kwargs...) where {uType, iip, T, V, P}
618+
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
619+
return SciMLBase.build_solution(prob, alg, ForwardDiff.Dual{T, V, P}(sol.u, partials),
620+
sol.resid; retcode = sol.retcode,
621+
left = ForwardDiff.Dual{T, V, P}(sol.left, partials),
622+
right = ForwardDiff.Dual{T, V, P}(sol.right, partials))
623+
end
624+
625+
function SciMLBase.solve(
626+
prob::IntervalNonlinearProblem{uType, iip,
627+
<:AbstractArray{
628+
<:ForwardDiff.Dual{T,
629+
V,
630+
P},
631+
}},
632+
alg::InternalITP, args...;
633+
kwargs...) where {uType, iip, T, V, P}
634+
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
635+
636+
return SciMLBase.build_solution(prob, alg, ForwardDiff.Dual{T, V, P}(sol.u, partials),
637+
sol.resid; retcode = sol.retcode,
638+
left = ForwardDiff.Dual{T, V, P}(sol.left, partials),
639+
right = ForwardDiff.Dual{T, V, P}(sol.right, partials))
640+
end
641+
642+
end

src/DiffEqBase.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,13 @@ import TruncatedStacktraces
3232

3333
using Setfield
3434

35-
using ForwardDiff
36-
3735
using EnumX
3836

3937
using Markdown
4038

4139
using ConcreteStructs: @concrete
4240
using FastClosures: @closure
4341

44-
# Could be made optional/glue
45-
import PreallocationTools
46-
4742
import FunctionWrappersWrappers
4843

4944
using SciMLBase
@@ -111,6 +106,13 @@ Reexport.@reexport using SciMLBase
111106

112107
SciMLBase.isfunctionwrapper(x::FunctionWrapper) = true
113108

109+
## Extension Functions
110+
111+
eltypedual(x) = false
112+
promote_u0(::Nothing, p, t0) = nothing
113+
114+
## Types
115+
114116
"""
115117
$(TYPEDEF)
116118
"""
@@ -132,7 +134,6 @@ include("utils.jl")
132134
include("stats.jl")
133135
include("calculate_residuals.jl")
134136
include("tableaus.jl")
135-
include("internal_falsi.jl")
136137
include("internal_itp.jl")
137138

138139
include("callbacks.jl")

0 commit comments

Comments
 (0)