Skip to content

Commit 6c9fbac

Browse files
claudeChrisRackauckas
authored andcommitted
Specialize functions on f parameter for trim compatibility
Similar to #684 and SciML/OrdinaryDiffEq.jl#2854, this PR adds type specialization for the f parameter in several functions across NonlinearSolve.jl to improve compatibility with --trim and reduce dynamic dispatch. Functions specialized: - _make_py_residual in NonlinearSolveSciPy.jl - _make_py_scalar in NonlinearSolveSciPy.jl - dogleg_method!! in trust_region.jl - construct_jacobian in jacobian_handling.jl (both general and AutoEnzyme versions) These functions either call f directly, pass f to other functions, or work with f in ways that benefit from type specialization for better compiler optimizations.
1 parent b37b31b commit 6c9fbac

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

lib/NonlinearSolveHomotopyContinuation/src/jacobian_handling.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ and `p` the parameter object.
111111
112112
The returned function must have the signature required by `HomotopySystemWrapper`.
113113
"""
114-
function construct_jacobian(f, autodiff, variant, u0, p)
114+
function construct_jacobian(f::F, autodiff, variant, u0, p) where F
115115
if variant == Scalar
116116
tmp = reinterpret(Float64, Vector{ComplexF64}(undef, 1))
117117
else
@@ -182,7 +182,7 @@ end
182182
183183
Construct an `EnzymeJacobian` function.
184184
"""
185-
function construct_jacobian(f, autodiff::AutoEnzyme, variant, u0, p)
185+
function construct_jacobian(f::F, autodiff::AutoEnzyme, variant, u0, p) where F
186186
if variant == Scalar
187187
prep = DI.prepare_derivative(f, autodiff, u0, DI.Constant(p), strict = Val(false))
188188
else

lib/NonlinearSolveSciPy/src/NonlinearSolveSciPy.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ end
9494
"""
9595
Internal: wrap a Julia residual function into a Python callable
9696
"""
97-
function _make_py_residual(f, p)
97+
function _make_py_residual(f::F, p) where F
9898
return pyfunc(x_py -> begin
9999
x = Vector{Float64}(x_py)
100100
r = f(x, p)
@@ -105,7 +105,7 @@ end
105105
"""
106106
Internal: wrap a Julia scalar function into a Python callable
107107
"""
108-
function _make_py_scalar(f, p)
108+
function _make_py_scalar(f::F, p) where F
109109
return pyfunc(x_py -> begin
110110
x = Float64(x_py)
111111
return f(x, p)

lib/SimpleNonlinearSolve/src/trust_region.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ function SciMLBase.__solve(
193193
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters)
194194
end
195195

196-
function dogleg_method!!(cache, J, f, g, Δ)
196+
function dogleg_method!!(cache, J, f::F, g, Δ) where F
197197
(; δsd, δN_δsd, δN) = cache
198198

199199
# Compute the Newton step

0 commit comments

Comments
 (0)