From df3a3f6e24c57c6c6923d8f01c60db6b80d2eec5 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 21 Aug 2025 08:02:51 -0400 Subject: [PATCH] Specialize more functions on f parameter for trim compatibility Similar to #2854, this PR adds type specialization for the f parameter in several functions across OrdinaryDiffEq.jl to improve compatibility with --trim and reduce dynamic dispatch. Functions specialized: - jacobian scalar fallback in derivative_wrappers.jl - sparsity_colorvec in derivative_wrappers.jl - WOperator constructor in derivative_utils.jl - islinearfunction in derivative_utils.jl - _compute_rhs in newton.jl - DAEResidualJacobianWrapper constructor in utils.jl - Interpolation functions in generic_dense.jl - verify_f2 functions in symplectic_perform_step.jl These functions all either call f directly, access fields of f, pass f to other functions, or perform type checks on f, so specialization enables better compiler optimizations. --- .../src/dense/generic_dense.jl | 20 +++++++++---------- .../src/derivative_utils.jl | 4 ++-- .../src/derivative_wrappers.jl | 4 ++-- .../src/newton.jl | 2 +- lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl | 2 +- .../src/symplectic_perform_step.jl | 12 +++++------ 6 files changed, 22 insertions(+), 22 deletions(-) diff --git a/lib/OrdinaryDiffEqCore/src/dense/generic_dense.jl b/lib/OrdinaryDiffEqCore/src/dense/generic_dense.jl index 4cda3adbc5..7a6bfb5fbf 100644 --- a/lib/OrdinaryDiffEqCore/src/dense/generic_dense.jl +++ b/lib/OrdinaryDiffEqCore/src/dense/generic_dense.jl @@ -477,17 +477,17 @@ end return expr end -function _evaluate_interpolant(f, Θ, dt, timeseries, i₋, i₊, +function _evaluate_interpolant(f::F, Θ, dt, timeseries, i₋, i₊, cache, idxs, - deriv, ks, ts, p, differential_vars) + deriv, ks, ts, p, differential_vars) where F _ode_addsteps!(ks[i₊], ts[i₋], timeseries[i₋], timeseries[i₊], dt, f, p, cache) # update the kcurrent return ode_interpolant(Θ, dt, timeseries[i₋], timeseries[i₊], ks[i₊], cache, idxs, deriv, differential_vars) end -function evaluate_composite_cache(f, Θ, dt, timeseries, i₋, i₊, +function evaluate_composite_cache(f::F, Θ, dt, timeseries, i₋, i₊, caches::Tuple{C1, C2, Vararg}, idxs, - deriv, ks, ts, p, cacheid, differential_vars) where {C1, C2} + deriv, ks, ts, p, cacheid, differential_vars) where {F, C1, C2} if (cacheid -= 1) != 0 return evaluate_composite_cache(f, Θ, dt, timeseries, i₋, i₊, Base.tail(caches), idxs, @@ -497,16 +497,16 @@ function evaluate_composite_cache(f, Θ, dt, timeseries, i₋, i₊, first(caches), idxs, deriv, ks, ts, p, differential_vars) end -function evaluate_composite_cache(f, Θ, dt, timeseries, i₋, i₊, +function evaluate_composite_cache(f::F, Θ, dt, timeseries, i₋, i₊, caches::Tuple{C}, idxs, - deriv, ks, ts, p, _, differential_vars) where {C} + deriv, ks, ts, p, _, differential_vars) where {F, C} _evaluate_interpolant(f, Θ, dt, timeseries, i₋, i₊, only(caches), idxs, deriv, ks, ts, p, differential_vars) end -function evaluate_default_cache(f, Θ, dt, timeseries, i₋, i₊, - cache::DefaultCache, idxs, deriv, ks, ts, p, cacheid, differential_vars) +function evaluate_default_cache(f::F, Θ, dt, timeseries, i₋, i₊, + cache::DefaultCache, idxs, deriv, ks, ts, p, cacheid, differential_vars) where F if cacheid == 1 return _evaluate_interpolant(f, Θ, dt, timeseries, i₋, i₊, cache.cache1, idxs, deriv, ks, ts, p, differential_vars) @@ -528,8 +528,8 @@ function evaluate_default_cache(f, Θ, dt, timeseries, i₋, i₊, end end -function evaluate_interpolant(f, Θ, dt, timeseries, i₋, i₊, cache, idxs, - deriv, ks, ts, id, p, differential_vars) +function evaluate_interpolant(f::F, Θ, dt, timeseries, i₋, i₊, cache, idxs, + deriv, ks, ts, id, p, differential_vars) where F if isdiscretecache(cache) return ode_interpolant(Θ, dt, timeseries[i₋], timeseries[i₊], 0, cache, idxs, deriv, differential_vars) diff --git a/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl b/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl index ab47d3aeef..b163761929 100644 --- a/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl +++ b/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl @@ -324,7 +324,7 @@ mutable struct WOperator{IIP, T, jacvec) end end -function WOperator{IIP}(f, u, gamma) where {IIP} +function WOperator{IIP}(f::F, u, gamma) where {IIP, F} if isa(f, Union{SplitFunction, DynamicalODEFunction}) error("WOperator does not support $(typeof(f)) yet") end @@ -440,7 +440,7 @@ islinearfunction(integrator) = islinearfunction(integrator.f, integrator.alg) return the tuple `(is_linear_wrt_odealg, islinearodefunction)`. """ -function islinearfunction(f, alg)::Tuple{Bool, Bool} +function islinearfunction(f::F, alg)::Tuple{Bool, Bool} where F isode = f isa ODEFunction && islinear(f.f) islin = isode || (issplit(alg) && f isa SplitFunction && islinear(f.f1.f)) return islin, isode diff --git a/lib/OrdinaryDiffEqDifferentiation/src/derivative_wrappers.jl b/lib/OrdinaryDiffEqDifferentiation/src/derivative_wrappers.jl index c5dbd16eb6..853a26771b 100644 --- a/lib/OrdinaryDiffEqDifferentiation/src/derivative_wrappers.jl +++ b/lib/OrdinaryDiffEqDifferentiation/src/derivative_wrappers.jl @@ -136,7 +136,7 @@ function jacobian(f::F, x::AbstractArray{<:Number}, integrator) where F end # fallback for scalar x, is needed for calc_J to work -function jacobian(f, x, integrator) +function jacobian(f::F, x, integrator) where F alg = unwrap_alg(integrator, true) dense = ADTypes.dense_ad(alg_autodiff(alg)) @@ -393,7 +393,7 @@ function build_grad_config(alg, f::F1, tf::F2, du1, t) where {F1, F2} end end -function sparsity_colorvec(f, x) +function sparsity_colorvec(f::F, x) where F sparsity = f.sparsity if is_sparse_csc(sparsity) diff --git a/lib/OrdinaryDiffEqNonlinearSolve/src/newton.jl b/lib/OrdinaryDiffEqNonlinearSolve/src/newton.jl index 00c633ba8d..3e593b7766 100644 --- a/lib/OrdinaryDiffEqNonlinearSolve/src/newton.jl +++ b/lib/OrdinaryDiffEqNonlinearSolve/src/newton.jl @@ -330,7 +330,7 @@ function compute_ustep!(ustep, tmp, γ, z, method) ustep end -function _compute_rhs(tmp, γ, α, tstep, invγdt, method::MethodType, p, dt, f, z) +function _compute_rhs(tmp, γ, α, tstep, invγdt, method::MethodType, p, dt, f::F, z) where F mass_matrix = f.mass_matrix ustep = compute_ustep(tmp, γ, z, method) if method === COEFFICIENT_MULTISTEP diff --git a/lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl b/lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl index 0a9cfe11e2..49a0db411d 100644 --- a/lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl +++ b/lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl @@ -71,7 +71,7 @@ mutable struct DAEResidualJacobianWrapper{isAD, F, pType, duType, uType, alphaTy tmp::tmpType uprev::uprevType t::tType - function DAEResidualJacobianWrapper(alg, f, p, α, invγdt, tmp, uprev, t) + function DAEResidualJacobianWrapper(alg, f::F, p, α, invγdt, tmp, uprev, t) where F ad = ADTypes.dense_ad(alg_autodiff(alg)) isautodiff = ad isa AutoForwardDiff if isautodiff diff --git a/lib/OrdinaryDiffEqSymplecticRK/src/symplectic_perform_step.jl b/lib/OrdinaryDiffEqSymplecticRK/src/symplectic_perform_step.jl index fcbf00b2f1..d70d3ce517 100644 --- a/lib/OrdinaryDiffEqSymplecticRK/src/symplectic_perform_step.jl +++ b/lib/OrdinaryDiffEqSymplecticRK/src/symplectic_perform_step.jl @@ -77,22 +77,22 @@ end # f.f2(p, q, pa, t) = p which is the Newton/Lagrange equations # If called with different functions (which are possible in the Hamiltonian case) # an exception is thrown to avoid silently calculate wrong results. -function verify_f2(f, p, q, pa, t, ::Any, - ::C) where {C <: Union{HamiltonConstantCache, VerletLeapfrogConstantCache, +function verify_f2(f::F, p, q, pa, t, ::Any, + ::C) where {F, C <: Union{HamiltonConstantCache, VerletLeapfrogConstantCache, LeapfrogDriftKickDriftConstantCache}} f(p, q, pa, t) end -function verify_f2(f, res, p, q, pa, t, ::Any, - ::C) where {C <: Union{HamiltonMutableCache, VerletLeapfrogCache, +function verify_f2(f::F, res, p, q, pa, t, ::Any, + ::C) where {F, C <: Union{HamiltonMutableCache, VerletLeapfrogCache, LeapfrogDriftKickDriftCache}} f(res, p, q, pa, t) end -function verify_f2(f, p, q, pa, t, integrator, ::C) where {C <: VelocityVerletConstantCache} +function verify_f2(f::F, p, q, pa, t, integrator, ::C) where {F, C <: VelocityVerletConstantCache} res = f(p, q, pa, t) res == p ? p : throwex(integrator) end -function verify_f2(f, res, p, q, pa, t, integrator, ::C) where {C <: VelocityVerletCache} +function verify_f2(f::F, res, p, q, pa, t, integrator, ::C) where {F, C <: VelocityVerletCache} f(res, p, q, pa, t) res == p ? res : throwex(integrator) end