Skip to content

Commit df3a3f6

Browse files
committed
Specialize more functions on f parameter for trim compatibility
Similar to SciML#2854, this PR adds type specialization for the f parameter in several functions across OrdinaryDiffEq.jl to improve compatibility with --trim and reduce dynamic dispatch. Functions specialized: - jacobian scalar fallback in derivative_wrappers.jl - sparsity_colorvec in derivative_wrappers.jl - WOperator constructor in derivative_utils.jl - islinearfunction in derivative_utils.jl - _compute_rhs in newton.jl - DAEResidualJacobianWrapper constructor in utils.jl - Interpolation functions in generic_dense.jl - verify_f2 functions in symplectic_perform_step.jl These functions all either call f directly, access fields of f, pass f to other functions, or perform type checks on f, so specialization enables better compiler optimizations.
1 parent a4ad621 commit df3a3f6

File tree

6 files changed

+22
-22
lines changed

6 files changed

+22
-22
lines changed

lib/OrdinaryDiffEqCore/src/dense/generic_dense.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -477,17 +477,17 @@ end
477477
return expr
478478
end
479479

480-
function _evaluate_interpolant(f, Θ, dt, timeseries, i₋, i₊,
480+
function _evaluate_interpolant(f::F, Θ, dt, timeseries, i₋, i₊,
481481
cache, idxs,
482-
deriv, ks, ts, p, differential_vars)
482+
deriv, ks, ts, p, differential_vars) where F
483483
_ode_addsteps!(ks[i₊], ts[i₋], timeseries[i₋], timeseries[i₊], dt, f, p,
484484
cache) # update the kcurrent
485485
return ode_interpolant(Θ, dt, timeseries[i₋], timeseries[i₊], ks[i₊],
486486
cache, idxs, deriv, differential_vars)
487487
end
488-
function evaluate_composite_cache(f, Θ, dt, timeseries, i₋, i₊,
488+
function evaluate_composite_cache(f::F, Θ, dt, timeseries, i₋, i₊,
489489
caches::Tuple{C1, C2, Vararg}, idxs,
490-
deriv, ks, ts, p, cacheid, differential_vars) where {C1, C2}
490+
deriv, ks, ts, p, cacheid, differential_vars) where {F, C1, C2}
491491
if (cacheid -= 1) != 0
492492
return evaluate_composite_cache(f, Θ, dt, timeseries, i₋, i₊, Base.tail(caches),
493493
idxs,
@@ -497,16 +497,16 @@ function evaluate_composite_cache(f, Θ, dt, timeseries, i₋, i₊,
497497
first(caches), idxs,
498498
deriv, ks, ts, p, differential_vars)
499499
end
500-
function evaluate_composite_cache(f, Θ, dt, timeseries, i₋, i₊,
500+
function evaluate_composite_cache(f::F, Θ, dt, timeseries, i₋, i₊,
501501
caches::Tuple{C}, idxs,
502-
deriv, ks, ts, p, _, differential_vars) where {C}
502+
deriv, ks, ts, p, _, differential_vars) where {F, C}
503503
_evaluate_interpolant(f, Θ, dt, timeseries, i₋, i₊,
504504
only(caches), idxs,
505505
deriv, ks, ts, p, differential_vars)
506506
end
507507

508-
function evaluate_default_cache(f, Θ, dt, timeseries, i₋, i₊,
509-
cache::DefaultCache, idxs, deriv, ks, ts, p, cacheid, differential_vars)
508+
function evaluate_default_cache(f::F, Θ, dt, timeseries, i₋, i₊,
509+
cache::DefaultCache, idxs, deriv, ks, ts, p, cacheid, differential_vars) where F
510510
if cacheid == 1
511511
return _evaluate_interpolant(f, Θ, dt, timeseries, i₋, i₊,
512512
cache.cache1, idxs, deriv, ks, ts, p, differential_vars)
@@ -528,8 +528,8 @@ function evaluate_default_cache(f, Θ, dt, timeseries, i₋, i₊,
528528
end
529529
end
530530

531-
function evaluate_interpolant(f, Θ, dt, timeseries, i₋, i₊, cache, idxs,
532-
deriv, ks, ts, id, p, differential_vars)
531+
function evaluate_interpolant(f::F, Θ, dt, timeseries, i₋, i₊, cache, idxs,
532+
deriv, ks, ts, id, p, differential_vars) where F
533533
if isdiscretecache(cache)
534534
return ode_interpolant(Θ, dt, timeseries[i₋], timeseries[i₊], 0, cache, idxs,
535535
deriv, differential_vars)

lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ mutable struct WOperator{IIP, T,
324324
jacvec)
325325
end
326326
end
327-
function WOperator{IIP}(f, u, gamma) where {IIP}
327+
function WOperator{IIP}(f::F, u, gamma) where {IIP, F}
328328
if isa(f, Union{SplitFunction, DynamicalODEFunction})
329329
error("WOperator does not support $(typeof(f)) yet")
330330
end
@@ -440,7 +440,7 @@ islinearfunction(integrator) = islinearfunction(integrator.f, integrator.alg)
440440
441441
return the tuple `(is_linear_wrt_odealg, islinearodefunction)`.
442442
"""
443-
function islinearfunction(f, alg)::Tuple{Bool, Bool}
443+
function islinearfunction(f::F, alg)::Tuple{Bool, Bool} where F
444444
isode = f isa ODEFunction && islinear(f.f)
445445
islin = isode || (issplit(alg) && f isa SplitFunction && islinear(f.f1.f))
446446
return islin, isode

lib/OrdinaryDiffEqDifferentiation/src/derivative_wrappers.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ function jacobian(f::F, x::AbstractArray{<:Number}, integrator) where F
136136
end
137137

138138
# fallback for scalar x, is needed for calc_J to work
139-
function jacobian(f, x, integrator)
139+
function jacobian(f::F, x, integrator) where F
140140
alg = unwrap_alg(integrator, true)
141141

142142
dense = ADTypes.dense_ad(alg_autodiff(alg))
@@ -393,7 +393,7 @@ function build_grad_config(alg, f::F1, tf::F2, du1, t) where {F1, F2}
393393
end
394394
end
395395

396-
function sparsity_colorvec(f, x)
396+
function sparsity_colorvec(f::F, x) where F
397397
sparsity = f.sparsity
398398

399399
if is_sparse_csc(sparsity)

lib/OrdinaryDiffEqNonlinearSolve/src/newton.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ function compute_ustep!(ustep, tmp, γ, z, method)
330330
ustep
331331
end
332332

333-
function _compute_rhs(tmp, γ, α, tstep, invγdt, method::MethodType, p, dt, f, z)
333+
function _compute_rhs(tmp, γ, α, tstep, invγdt, method::MethodType, p, dt, f::F, z) where F
334334
mass_matrix = f.mass_matrix
335335
ustep = compute_ustep(tmp, γ, z, method)
336336
if method === COEFFICIENT_MULTISTEP

lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ mutable struct DAEResidualJacobianWrapper{isAD, F, pType, duType, uType, alphaTy
7171
tmp::tmpType
7272
uprev::uprevType
7373
t::tType
74-
function DAEResidualJacobianWrapper(alg, f, p, α, invγdt, tmp, uprev, t)
74+
function DAEResidualJacobianWrapper(alg, f::F, p, α, invγdt, tmp, uprev, t) where F
7575
ad = ADTypes.dense_ad(alg_autodiff(alg))
7676
isautodiff = ad isa AutoForwardDiff
7777
if isautodiff

lib/OrdinaryDiffEqSymplecticRK/src/symplectic_perform_step.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,22 +77,22 @@ end
7777
# f.f2(p, q, pa, t) = p which is the Newton/Lagrange equations
7878
# If called with different functions (which are possible in the Hamiltonian case)
7979
# an exception is thrown to avoid silently calculate wrong results.
80-
function verify_f2(f, p, q, pa, t, ::Any,
81-
::C) where {C <: Union{HamiltonConstantCache, VerletLeapfrogConstantCache,
80+
function verify_f2(f::F, p, q, pa, t, ::Any,
81+
::C) where {F, C <: Union{HamiltonConstantCache, VerletLeapfrogConstantCache,
8282
LeapfrogDriftKickDriftConstantCache}}
8383
f(p, q, pa, t)
8484
end
85-
function verify_f2(f, res, p, q, pa, t, ::Any,
86-
::C) where {C <: Union{HamiltonMutableCache, VerletLeapfrogCache,
85+
function verify_f2(f::F, res, p, q, pa, t, ::Any,
86+
::C) where {F, C <: Union{HamiltonMutableCache, VerletLeapfrogCache,
8787
LeapfrogDriftKickDriftCache}}
8888
f(res, p, q, pa, t)
8989
end
9090

91-
function verify_f2(f, p, q, pa, t, integrator, ::C) where {C <: VelocityVerletConstantCache}
91+
function verify_f2(f::F, p, q, pa, t, integrator, ::C) where {F, C <: VelocityVerletConstantCache}
9292
res = f(p, q, pa, t)
9393
res == p ? p : throwex(integrator)
9494
end
95-
function verify_f2(f, res, p, q, pa, t, integrator, ::C) where {C <: VelocityVerletCache}
95+
function verify_f2(f::F, res, p, q, pa, t, integrator, ::C) where {F, C <: VelocityVerletCache}
9696
f(res, p, q, pa, t)
9797
res == p ? res : throwex(integrator)
9898
end

0 commit comments

Comments
 (0)