Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 31 additions & 12 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4972,13 +4972,16 @@ function __has_initializeprob(f)
has_initialization_data(f) && hasfield(typeof(f.initialization_data), :initializeprob)
end
function __has_update_initializeprob!(f)
has_initialization_data(f) && hasfield(typeof(f.initialization_data), :update_initializeprob!)
has_initialization_data(f) &&
hasfield(typeof(f.initialization_data), :update_initializeprob!)
end
function __has_initializeprobmap(f)
has_initialization_data(f) && hasfield(typeof(f.initialization_data), :initializeprobmap)
has_initialization_data(f) &&
hasfield(typeof(f.initialization_data), :initializeprobmap)
end
function __has_initializeprobpmap(f)
has_initialization_data(f) && hasfield(typeof(f.initialization_data), :initializeprobpmap)
has_initialization_data(f) &&
hasfield(typeof(f.initialization_data), :initializeprobpmap)
end
__has_initialization_data(f) = hasfield(typeof(f), :initialization_data)
__has_polynomialize(f) = hasfield(typeof(f), :polynomialize)
Expand Down Expand Up @@ -5046,15 +5049,31 @@ function has_observed(f::AbstractSciMLFunction)
end
has_colorvec(f::AbstractSciMLFunction) = __has_colorvec(f) && f.colorvec !== nothing

# TODO: find an appropriate way to check `has_*`
has_jac(f::Union{SplitFunction, SplitSDEFunction}) = has_jac(f.f1)
has_jvp(f::Union{SplitFunction, SplitSDEFunction}) = has_jvp(f.f1)
has_vjp(f::Union{SplitFunction, SplitSDEFunction}) = has_vjp(f.f1)
has_tgrad(f::Union{SplitFunction, SplitSDEFunction}) = has_tgrad(f.f1)
has_Wfact(f::Union{SplitFunction, SplitSDEFunction}) = has_Wfact(f.f1)
has_Wfact_t(f::Union{SplitFunction, SplitSDEFunction}) = has_Wfact_t(f.f1)
has_paramjac(f::Union{SplitFunction, SplitSDEFunction}) = has_paramjac(f.f1)
has_colorvec(f::Union{SplitFunction, SplitSDEFunction}) = has_colorvec(f.f1)
# Check the SplitFunction's own fields first before delegating to f.f1
function has_jac(f::Union{SplitFunction, SplitSDEFunction})
(__has_jac(f) && f.jac !== nothing) || has_jac(f.f1)
end
function has_jvp(f::Union{SplitFunction, SplitSDEFunction})
(__has_jvp(f) && f.jvp !== nothing) || has_jvp(f.f1)
end
function has_vjp(f::Union{SplitFunction, SplitSDEFunction})
(__has_vjp(f) && f.vjp !== nothing) || has_vjp(f.f1)
end
function has_tgrad(f::Union{SplitFunction, SplitSDEFunction})
(__has_tgrad(f) && f.tgrad !== nothing) || has_tgrad(f.f1)
end
function has_Wfact(f::Union{SplitFunction, SplitSDEFunction})
(__has_Wfact(f) && f.Wfact !== nothing) || has_Wfact(f.f1)
end
function has_Wfact_t(f::Union{SplitFunction, SplitSDEFunction})
(__has_Wfact_t(f) && f.Wfact_t !== nothing) || has_Wfact_t(f.f1)
end
function has_paramjac(f::Union{SplitFunction, SplitSDEFunction})
(__has_paramjac(f) && f.paramjac !== nothing) || has_paramjac(f.f1)
end
function has_colorvec(f::Union{SplitFunction, SplitSDEFunction})
(__has_colorvec(f) && f.colorvec !== nothing) || has_colorvec(f.f1)
end

has_jac(f::Union{DynamicalODEFunction, DynamicalDDEFunction}) = has_jac(f.f1)
has_jvp(f::Union{DynamicalODEFunction, DynamicalDDEFunction}) = has_jvp(f.f1)
Expand Down
Loading