Skip to content

Commit 9fba92d

Browse files
Fix SplitFunction to use user-provided jvp function
When a user provides a jvp function to SplitFunction, it should be used instead of falling back to automatic differentiation. The has_jvp check was only checking f.f1 instead of first checking the SplitFunction's own jvp field. This fix checks the SplitFunction's own fields (jvp, vjp, jac, etc.) first before delegating to f.f1, which allows user-provided analytical derivatives to be used properly with matrix-free Krylov solvers. Fixes SciML/DifferentialEquations.jl#1109 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent 92d8460 commit 9fba92d

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

src/scimlfunctions.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5046,15 +5046,15 @@ function has_observed(f::AbstractSciMLFunction)
50465046
end
50475047
has_colorvec(f::AbstractSciMLFunction) = __has_colorvec(f) && f.colorvec !== nothing
50485048

5049-
# TODO: find an appropriate way to check `has_*`
5050-
has_jac(f::Union{SplitFunction, SplitSDEFunction}) = has_jac(f.f1)
5051-
has_jvp(f::Union{SplitFunction, SplitSDEFunction}) = has_jvp(f.f1)
5052-
has_vjp(f::Union{SplitFunction, SplitSDEFunction}) = has_vjp(f.f1)
5053-
has_tgrad(f::Union{SplitFunction, SplitSDEFunction}) = has_tgrad(f.f1)
5054-
has_Wfact(f::Union{SplitFunction, SplitSDEFunction}) = has_Wfact(f.f1)
5055-
has_Wfact_t(f::Union{SplitFunction, SplitSDEFunction}) = has_Wfact_t(f.f1)
5056-
has_paramjac(f::Union{SplitFunction, SplitSDEFunction}) = has_paramjac(f.f1)
5057-
has_colorvec(f::Union{SplitFunction, SplitSDEFunction}) = has_colorvec(f.f1)
5049+
# Check the SplitFunction's own fields first before delegating to f.f1
5050+
has_jac(f::Union{SplitFunction, SplitSDEFunction}) = (__has_jac(f) && f.jac !== nothing) || has_jac(f.f1)
5051+
has_jvp(f::Union{SplitFunction, SplitSDEFunction}) = (__has_jvp(f) && f.jvp !== nothing) || has_jvp(f.f1)
5052+
has_vjp(f::Union{SplitFunction, SplitSDEFunction}) = (__has_vjp(f) && f.vjp !== nothing) || has_vjp(f.f1)
5053+
has_tgrad(f::Union{SplitFunction, SplitSDEFunction}) = (__has_tgrad(f) && f.tgrad !== nothing) || has_tgrad(f.f1)
5054+
has_Wfact(f::Union{SplitFunction, SplitSDEFunction}) = (__has_Wfact(f) && f.Wfact !== nothing) || has_Wfact(f.f1)
5055+
has_Wfact_t(f::Union{SplitFunction, SplitSDEFunction}) = (__has_Wfact_t(f) && f.Wfact_t !== nothing) || has_Wfact_t(f.f1)
5056+
has_paramjac(f::Union{SplitFunction, SplitSDEFunction}) = (__has_paramjac(f) && f.paramjac !== nothing) || has_paramjac(f.f1)
5057+
has_colorvec(f::Union{SplitFunction, SplitSDEFunction}) = (__has_colorvec(f) && f.colorvec !== nothing) || has_colorvec(f.f1)
50585058

50595059
has_jac(f::Union{DynamicalODEFunction, DynamicalDDEFunction}) = has_jac(f.f1)
50605060
has_jvp(f::Union{DynamicalODEFunction, DynamicalDDEFunction}) = has_jvp(f.f1)

0 commit comments

Comments
 (0)