From b8bd7ab8b395a59b5120a8d8368997ce9c31898d Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Sun, 3 Aug 2025 11:23:59 -0400 Subject: [PATCH 1/2] Move RecipesBase and RuntimeGeneratedFunctions to extensions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit moves RecipesBase and RuntimeGeneratedFunctions from direct dependencies to weak dependencies with corresponding extensions, reducing the load time footprint of SciMLBase. Changes: - Moved RecipesBase from deps to weakdeps in Project.toml - Moved RuntimeGeneratedFunctions from deps to weakdeps in Project.toml - Created SciMLBaseRecipesBaseExt.jl extension containing all @recipe definitions - Created SciMLBaseRuntimeGeneratedFunctionsExt.jl extension with numargs method - Removed RecipesBase import from main SciMLBase.jl module - Removed RuntimeGeneratedFunctions import from main SciMLBase.jl module - Removed all @recipe function definitions from original source files - Removed numargs method for RuntimeGeneratedFunctions from utils.jl The plotting functionality is now only available when RecipesBase is explicitly loaded, maintaining backward compatibility while reducing the default dependency footprint. The RuntimeGeneratedFunctions support for numargs is similarly conditional. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- Project.toml | 6 +- ext/SciMLBaseRecipesBaseExt.jl | 500 +++++++++++++++++++ ext/SciMLBaseRuntimeGeneratedFunctionsExt.jl | 20 + src/SciMLBase.jl | 3 +- src/ensemble/ensemble_solutions.jl | 69 --- src/integrator_interface.jl | 129 ----- src/solutions/solution_interface.jl | 168 ------- src/utils.jl | 13 - 8 files changed, 524 insertions(+), 384 deletions(-) create mode 100644 ext/SciMLBaseRecipesBaseExt.jl create mode 100644 ext/SciMLBaseRuntimeGeneratedFunctionsExt.jl diff --git a/Project.toml b/Project.toml index 92b029a3a8..ddbef19308 100644 --- a/Project.toml +++ b/Project.toml @@ -22,10 +22,8 @@ Moshi = "2e0e35c7-a2e4-4343-998d-7ef72827ed2d" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" -RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" -RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" @@ -41,6 +39,7 @@ PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" RCall = "6f49c342-dc21-5d91-9882-a32aef131414" +RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] @@ -51,6 +50,8 @@ SciMLBasePartialFunctionsExt = "PartialFunctions" SciMLBasePyCallExt = "PyCall" SciMLBasePythonCallExt = "PythonCall" SciMLBaseRCallExt = "RCall" +SciMLBaseRecipesBaseExt = "RecipesBase" +SciMLBaseRuntimeGeneratedFunctionsExt = "RuntimeGeneratedFunctions" SciMLBaseZygoteExt = ["Zygote", "ChainRulesCore"] [compat] @@ -81,7 +82,6 @@ Printf = "1.10" PyCall = "1.96" PythonCall = "0.9.15" RCall = "0.14.0" -RecipesBase = "1.3.4" RecursiveArrayTools = "3.35" Reexport = "1" RuntimeGeneratedFunctions = "0.5.12" diff --git a/ext/SciMLBaseRecipesBaseExt.jl b/ext/SciMLBaseRecipesBaseExt.jl new file mode 100644 index 0000000000..5028f9dd73 --- /dev/null +++ b/ext/SciMLBaseRecipesBaseExt.jl @@ -0,0 +1,500 @@ +module SciMLBaseRecipesBaseExt + +using SciMLBase +using RecipesBase +import RecursiveArrayTools + +# Need to import the plotting-related functions +import SciMLBase: DEFAULT_PLOT_FUNC, isdenseplot, plottable_indices, interpret_vars, + get_all_timeseries_indexes, ContinuousTimeseries, DiscreteTimeseries, + solution_slice, add_labels!, AbstractTimeseriesSolution, AbstractEnsembleSolution, + AbstractNoTimeSolution, EnsembleSummary, DEIntegrator, AbstractSDEIntegrator, + getindepsym_defaultt, getname, hasname, u_n, AbstractDEAlgorithm + +# Recipe for AbstractTimeseriesSolution +@recipe function f(sol::AbstractTimeseriesSolution; + plot_analytic = false, + denseplot = isdenseplot(sol), + plotdensity = min(Int(1e5), + sol.tslocation == 0 ? + (sol.prob isa SciMLBase.AbstractDiscreteProblem ? + max(1000, 100 * length(sol)) : + max(1000, 10 * length(sol))) : + 1000 * sol.tslocation), plotat = nothing, + tspan = nothing, + vars = nothing, idxs = nothing) + if vars !== nothing + Base.depwarn( + "To maintain consistency with solution indexing, keyword argument vars will be removed in a future version. Please use keyword argument idxs instead.", + :f; force = true) + (idxs !== nothing) && + error("Simultaneously using keywords vars and idxs is not supported. Please only use idxs.") + idxs = vars + end + + if plot_analytic && (sol.u_analytic === nothing) + throw(ArgumentError("No analytic solution was found but `plot_analytic` was set to `true`.")) + end + + idxs = idxs === nothing ? plottable_indices(sol.u[1]) : idxs + if !(idxs isa Union{Tuple, AbstractArray}) + vars = interpret_vars([idxs], sol) + else + vars = interpret_vars(idxs, sol) + end + disc_vars = Tuple[] + cont_vars = Tuple[] + for var in vars + tsidxs = union(get_all_timeseries_indexes(sol, var[2]), + get_all_timeseries_indexes(sol, var[3])) + if ContinuousTimeseries() in tsidxs || isempty(tsidxs) + push!(cont_vars, var) + else + push!(disc_vars, var) + end + end + + plot_vecs = [] + labels = [] + + # Handle continuous variables + if !isempty(cont_vars) + int_vars = cont_vars + + if tspan === nothing + if plotat === nothing + if denseplot + # Generate the points from the plot from dense function + start_idx = sol.tslocation == 0 ? 1 : sol.tslocation + end_idx = length(sol.t) + plott = collect(range(sol.t[start_idx], sol.t[end_idx]; length = plotdensity)) + plot_timeseries = sol(plott) + if plot_analytic + plot_analytic_timeseries = [sol.prob.f.analytic(sol.prob.u0, + sol.prob.p, t) + for t in plott] + end + else + plot_timeseries = sol.u + plott = sol.t + if plot_analytic + plot_analytic_timeseries = sol.u_analytic + end + end + else + plot_timeseries = sol(plotat) + plott = plotat + if plot_analytic + plot_analytic_timeseries = [sol.prob.f.analytic(sol.prob.u0, + sol.prob.p, t) for t in plott] + end + end + else + _tspan = tspan isa Number ? (sol.t[1], tspan) : tspan + start_idx = findfirst(x -> x >= _tspan[1], sol.t) + end_idx = findlast(x -> x <= _tspan[2], sol.t) + if denseplot + plott = collect(range(_tspan...; length = plotdensity)) + plot_timeseries = sol(plott) + if plot_analytic + plot_analytic_timeseries = [sol.prob.f.analytic(sol.prob.u0, + sol.prob.p, t) for t in plott] + end + else + if start_idx === nothing + start_idx = 1 + end + if end_idx === nothing + end_idx = length(sol.t) + end + plott = @view sol.t[start_idx:end_idx] + plot_timeseries = @view sol.u[start_idx:end_idx] + if plot_analytic + plot_analytic_timeseries = @view sol.u_analytic[start_idx:end_idx] + end + end + end + + dims = length(int_vars[1]) + for var in int_vars + @assert length(var) == dims + end + # Should check that all have the same dims! + + for i in 2:dims + push!(plot_vecs, []) + end + + labels = String[]# Array{String, 2}(1, length(int_vars)*(1+plot_analytic)) + strs = String[] + varsyms = SciMLBase.variable_symbols(sol) + + for x in int_vars + for j in 2:dims + if denseplot + if (x[j] isa Integer && x[j] == 0) || + isequal(x[j], SciMLBase.getindepsym_defaultt(sol)) + push!(plot_vecs[j - 1], plott) + else + # For the dense plotting case, we use getindex on the timeseries + if plot_timeseries isa AbstractArray + if x[j] isa Integer + # Simple integer indexing + push!(plot_vecs[j - 1], [u[x[j]] for u in plot_timeseries]) + else + # Symbol indexing + push!(plot_vecs[j - 1], Vector(sol(plott; idxs = x[j]))) + end + else + # Single value case + push!(plot_vecs[j - 1], Vector(sol(plott; idxs = x[j]))) + end + end + else # just get values + if x[j] == 0 + push!(plot_vecs[j - 1], plott) + elseif x[j] == 1 && !(eltype(plot_timeseries) <: AbstractArray) + push!(plot_vecs[j - 1], plot_timeseries) + else + if x[j] isa Integer + push!(plot_vecs[j - 1], [u[x[j]] for u in plot_timeseries]) + else + # Symbol indexing + push!(plot_vecs[j - 1], [sol(t, idxs = x[j]) for t in plott]) + end + end + end + + if !isempty(varsyms) && x[j] isa Integer + push!(strs, String(getname(varsyms[x[j]]))) + elseif hasname(x[j]) + push!(strs, String(getname(x[j]))) + else + push!(strs, "u[$(x[j])]") + end + end + add_labels!(labels, x, dims, sol, strs) + end + + if plot_analytic + for x in int_vars + for j in 2:dims + if denseplot + if (x[j] isa Integer && x[j] == 0) || + isequal(x[j], SciMLBase.getindepsym_defaultt(sol)) + push!(plot_vecs[j - 1], plott) + else + push!(plot_vecs[j - 1], + u_n(plot_analytic_timeseries, x[j], sol, plott, + plot_analytic_timeseries)) + end + else # Just get values + if x[j] == 0 + push!(plot_vecs[j - 1], plott) + elseif x[j] == 1 && + !(eltype(plot_analytic_timeseries) <: AbstractArray) + push!(plot_vecs[j - 1], plot_analytic_timeseries) + else + push!(plot_vecs[j - 1], + u_n(plot_analytic_timeseries, x[j], sol, plott, + plot_analytic_timeseries)) + end + end + end + add_labels!(labels, x, dims, sol, strs) + end + end + + xflip --> sol.tdir < 0 + + if denseplot + seriestype --> :path + else + seriestype --> :scatter + end + + # Special case labels when idxs = (:x,:y,:z) or (:x) or [:x,:y] .. + if idxs isa Tuple && (typeof(idxs[1]) == Symbol && typeof(idxs[2]) == Symbol) + xlabel --> idxs[1] + ylabel --> idxs[2] + if length(idxs) > 2 + zlabel --> idxs[3] + end + end + if getindex.(int_vars, 1) == zeros(length(int_vars)) || + getindex.(int_vars, 2) == zeros(length(int_vars)) + xlabel --> "t" + end + + linewidth --> 3 + #xtickfont --> font(11) + #ytickfont --> font(11) + #legendfont --> font(11) + #guidefont --> font(11) + label --> reshape(labels, 1, length(labels)) + (plot_vecs...,) + + # Handle discrete variables + elseif !isempty(disc_vars) + int_vars = disc_vars + + if sol.tslocation != 0 + start_idx = sol.tslocation + else + start_idx = 1 + end + + if tspan === nothing + end_idx = length(sol.t) + else + _tspan = tspan isa Number ? (sol.t[1], tspan) : tspan + end_idx = findlast(x -> x <= _tspan[2], sol.t) + if end_idx === nothing + end_idx = length(sol.t) + end + end + + plott = sol.t[start_idx:end_idx] + plot_timeseries = sol.u[start_idx:end_idx] + + dims = length(int_vars[1]) + for var in int_vars + @assert length(var) == dims + end + + for i in 2:dims + push!(plot_vecs, []) + end + + labels = String[] + strs = String[] + varsyms = SciMLBase.variable_symbols(sol) + + for x in int_vars + for j in 2:dims + if x[j] == 0 + push!(plot_vecs[j - 1], plott) + elseif x[j] == 1 && !(eltype(plot_timeseries) <: AbstractArray) + push!(plot_vecs[j - 1], plot_timeseries) + else + if x[j] isa Integer + push!(plot_vecs[j - 1], [u[x[j]] for u in plot_timeseries]) + else + # Symbol indexing for discrete case + push!(plot_vecs[j - 1], [sol[ti, x[j]] for ti in 1:length(plott)]) + end + end + + if !isempty(varsyms) && x[j] isa Integer + push!(strs, String(getname(varsyms[x[j]]))) + elseif hasname(x[j]) + push!(strs, String(getname(x[j]))) + else + push!(strs, "u[$(x[j])]") + end + end + add_labels!(labels, x, dims, sol, strs) + end + + seriestype --> :steppost + if getindex.(int_vars, 1) == zeros(length(int_vars)) || + getindex.(int_vars, 2) == zeros(length(int_vars)) + xlabel --> "t" + end + + linewidth --> 3 + label --> reshape(labels, 1, length(labels)) + (plot_vecs...,) + end +end + +# Recipe for AbstractEnsembleSolution +@recipe function f(sim::AbstractEnsembleSolution; idxs = nothing, + summarize = true, error_style = :ribbon, ci_type = :quantile, linealpha = 0.4, zorder = 1) + + if idxs === nothing + if sim.u[1] isa SciMLBase.AbstractTimeseriesSolution + idxs = 1:length(sim.u[1].u[1]) + else + idxs = 1:length(sim.u[1]) + end + end + + if !(idxs isa Union{Tuple, AbstractArray}) + idxs = [idxs] + end + + if summarize + summary = EnsembleSummary(sim; quantiles = [0.05, 0.95]) + if error_style == :ribbon + ribbon --> (summary.qlow[:, idxs], summary.qhigh[:, idxs]) + elseif error_style == :bars + yerror --> (summary.qlow[:, idxs], summary.qhigh[:, idxs]) + end + summary.t, summary.med[:, idxs] + else + alpha --> linealpha + # Plot all trajectories + for i in eachindex(sim.u) + @series begin + if sim.u[i] isa SciMLBase.AbstractTimeseriesSolution + idxs --> idxs + sim.u[i] + else + # For non-timeseries solutions + sim.u[i][idxs] + end + end + end + end +end + +# Recipe for EnsembleSummary +@recipe function f(sim::EnsembleSummary; idxs = nothing, error_style = :ribbon) + if idxs === nothing + idxs = 1:size(sim.med, 2) + end + + if !(idxs isa Union{Tuple, AbstractArray}) + idxs = [idxs] + end + + if error_style == :ribbon + ribbon --> (sim.qlow[:, idxs], sim.qhigh[:, idxs]) + elseif error_style == :bars + yerror --> (sim.qlow[:, idxs], sim.qhigh[:, idxs]) + end + sim.t, sim.med[:, idxs] +end + +# Recipe for DEIntegrator +@recipe function f(integrator::DEIntegrator; + denseplot = (integrator.opts.calck || + integrator isa AbstractSDEIntegrator) && + integrator.iter > 0, + plotdensity = 10, + plot_analytic = false, vars = nothing, idxs = nothing) + if vars !== nothing + Base.depwarn( + "To maintain consistency with solution indexing, keyword argument vars will be removed in a future version. Please use keyword argument idxs instead.", + :f; force = true) + (idxs !== nothing) && + error("Simultaneously using keywords vars and idxs is not supported. Please only use idxs.") + idxs = vars + end + + int_vars = interpret_vars(idxs, integrator.sol) + + if denseplot + # Generate the points from the plot from dense function + plott = collect(range(integrator.tprev, integrator.t; length = plotdensity)) + if plot_analytic + plot_analytic_timeseries = [integrator.sol.prob.f.analytic( + integrator.sol.prob.u0, + integrator.sol.prob.p, + t) for t in plott] + end + else + plott = nothing + end + + dims = length(int_vars[1]) + for var in int_vars + @assert length(var) == dims + end + # Should check that all have the same dims! + + plot_vecs = [] + for i in 2:dims + push!(plot_vecs, []) + end + + labels = String[]# Array{String, 2}(1, length(int_vars)*(1+plot_analytic)) + strs = String[] + varsyms = SciMLBase.variable_symbols(integrator) + + for x in int_vars + for j in 2:dims + if denseplot + if (x[j] isa Integer && x[j] == 0) || + isequal(x[j], getindepsym_defaultt(integrator)) + push!(plot_vecs[j - 1], plott) + else + push!(plot_vecs[j - 1], Vector(integrator(plott; idxs = x[j]))) + end + else # just get values + if x[j] == 0 + push!(plot_vecs[j - 1], integrator.t) + elseif x[j] == 1 && !(integrator.u isa AbstractArray) + push!(plot_vecs[j - 1], integrator.u) + else + push!(plot_vecs[j - 1], integrator.u[x[j]]) + end + end + + if !isempty(varsyms) && x[j] isa Integer + push!(strs, String(getname(varsyms[x[j]]))) + elseif hasname(x[j]) + push!(strs, String(getname(x[j]))) + else + push!(strs, "u[$(x[j])]") + end + end + add_labels!(labels, x, dims, integrator.sol, strs) + end + + if plot_analytic + for x in int_vars + for j in 1:dims + if denseplot + push!(plot_vecs[j], + u_n(plot_timeseries, x[j], sol, plott, plot_timeseries)) + else # Just get values + if x[j] == 0 + push!(plot_vecs[j], integrator.t) + elseif x[j] == 1 && !(integrator.u isa AbstractArray) + push!(plot_vecs[j], + integrator.sol.prob.f(Val{:analytic}, integrator.t, + integrator.sol[1])) + else + push!(plot_vecs[j], + integrator.sol.prob.f(Val{:analytic}, integrator.t, + integrator.sol[1])[x[j]]) + end + end + end + add_labels!(labels, x, dims, integrator.sol, strs) + end + end + + xflip --> integrator.tdir < 0 + + if denseplot + seriestype --> :path + else + seriestype --> :scatter + end + + # Special case labels when idxs = (:x,:y,:z) or (:x) or [:x,:y] .. + if idxs isa Tuple && (typeof(idxs[1]) == Symbol && typeof(idxs[2]) == Symbol) + xlabel --> idxs[1] + ylabel --> idxs[2] + if length(idxs) > 2 + zlabel --> idxs[3] + end + end + if getindex.(int_vars, 1) == zeros(length(int_vars)) || + getindex.(int_vars, 2) == zeros(length(int_vars)) + xlabel --> "t" + end + + linewidth --> 3 + #xtickfont --> font(11) + #ytickfont --> font(11) + #legendfont --> font(11) + #guidefont --> font(11) + label --> reshape(labels, 1, length(labels)) + (plot_vecs...,) +end + +end \ No newline at end of file diff --git a/ext/SciMLBaseRuntimeGeneratedFunctionsExt.jl b/ext/SciMLBaseRuntimeGeneratedFunctionsExt.jl new file mode 100644 index 0000000000..8e5eb1c565 --- /dev/null +++ b/ext/SciMLBaseRuntimeGeneratedFunctionsExt.jl @@ -0,0 +1,20 @@ +module SciMLBaseRuntimeGeneratedFunctionsExt + +using SciMLBase +using RuntimeGeneratedFunctions + +function SciMLBase.numargs(f::RuntimeGeneratedFunctions.RuntimeGeneratedFunction{ + T, + V, + W, + I +}) where { + T, + V, + W, + I +} + (length(T),) +end + +end \ No newline at end of file diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index 49d9e246d8..06fb190691 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -4,7 +4,7 @@ if isdefined(Base, :Experimental) && @eval Base.Experimental.@max_methods 1 end using ConstructionBase -using RecipesBase, RecursiveArrayTools +using RecursiveArrayTools using SciMLStructures using SymbolicIndexingInterface using DocStringExtensions @@ -19,7 +19,6 @@ import Logging, ArrayInterface import IteratorInterfaceExtensions import CommonSolve: solve, init, step!, solve! import FunctionWrappersWrappers -import RuntimeGeneratedFunctions import EnumX import ADTypes: ADTypes, AbstractADType import Accessors: @set, @reset, @delete, @insert diff --git a/src/ensemble/ensemble_solutions.jl b/src/ensemble/ensemble_solutions.jl index ab25d11813..ffcb46ea7d 100644 --- a/src/ensemble/ensemble_solutions.jl +++ b/src/ensemble/ensemble_solutions.jl @@ -184,76 +184,7 @@ end ### Plot Recipes -@recipe function f(sim::AbstractEnsembleSolution; - zcolors = sim.u isa AbstractArray ? fill(nothing, length(sim.u)) : - nothing, - trajectories = eachindex(sim)) - for i in trajectories - size(sim.u[i].u, 1) == 0 && continue - @series begin - legend := false - xlims --> (-Inf, Inf) - ylims --> (-Inf, Inf) - zlims --> (-Inf, Inf) - marker_z --> zcolors[i] - sim.u[i] - end - end -end -@recipe function f(sim::EnsembleSummary; - idxs = sim.u.u[1] isa AbstractArray ? eachindex(sim.u.u[1]) : - 1, - error_style = :ribbon, ci_type = :quantile) - if ci_type == :SEM - if sim.u.u[1] isa AbstractArray - u = vecarr_to_vectors(sim.u) - else - u = [sim.u.u] - end - if sim.u.u[1] isa AbstractArray - ci_low = vecarr_to_vectors(VectorOfArray([sqrt.(sim.v.u[i] / sim.num_monte) .* - 1.96 for i in 1:length(sim.v)])) - ci_high = ci_low - else - ci_low = [[sqrt(sim.v.u[i] / length(sim.num_monte)) .* 1.96 - for i in 1:length(sim.v)]] - ci_high = ci_low - end - elseif ci_type == :quantile - if sim.med.u[1] isa AbstractArray - u = vecarr_to_vectors(sim.med) - else - u = [sim.med.u] - end - if sim.u.u[1] isa AbstractArray - ci_low = u - vecarr_to_vectors(sim.qlow) - ci_high = vecarr_to_vectors(sim.qhigh) - u - else - ci_low = [u[1] - sim.qlow.u] - ci_high = [sim.qhigh.u - u[1]] - end - else - error("ci_type choice not valid. Must be `:SEM` or `:quantile`") - end - for i in idxs - @series begin - legend --> false - linewidth --> 3 - fillalpha --> 0.2 - if error_style == :ribbon - ribbon --> (ci_low[i], ci_high[i]) - elseif error_style == :bars - yerror --> (ci_low[i], ci_high[i]) - elseif error_style == :none - nothing - else - error("error_style not recognized") - end - sim.t, u[i] - end - end -end function (sol::AbstractEnsembleSolution)(args...; kwargs...) [s(args...; kwargs...) for s in sol] diff --git a/src/integrator_interface.jl b/src/integrator_interface.jl index 3b13762858..becaaf7767 100644 --- a/src/integrator_interface.jl +++ b/src/integrator_interface.jl @@ -773,135 +773,6 @@ end Base.length(iter::TimeChoiceIterator) = length(iter.ts) -@recipe function f(integrator::DEIntegrator; - denseplot = (integrator.opts.calck || - integrator isa AbstractSDEIntegrator) && - integrator.iter > 0, - plotdensity = 10, - plot_analytic = false, vars = nothing, idxs = nothing) - if vars !== nothing - Base.depwarn( - "To maintain consistency with solution indexing, keyword argument vars will be removed in a future version. Please use keyword argument idxs instead.", - :f; force = true) - (idxs !== nothing) && - error("Simultaneously using keywords vars and idxs is not supported. Please only use idxs.") - idxs = vars - end - - int_vars = interpret_vars(idxs, integrator.sol) - - if denseplot - # Generate the points from the plot from dense function - plott = collect(range(integrator.tprev, integrator.t; length = plotdensity)) - if plot_analytic - plot_analytic_timeseries = [integrator.sol.prob.f.analytic( - integrator.sol.prob.u0, - integrator.sol.prob.p, - t) for t in plott] - end - else - plott = nothing - end - - dims = length(int_vars[1]) - for var in int_vars - @assert length(var) == dims - end - # Should check that all have the same dims! - - plot_vecs = [] - for i in 2:dims - push!(plot_vecs, []) - end - - labels = String[]# Array{String, 2}(1, length(int_vars)*(1+plot_analytic)) - strs = String[] - varsyms = variable_symbols(integrator) - @show plott - - for x in int_vars - for j in 2:dims - if denseplot - if (x[j] isa Integer && x[j] == 0) || - isequal(x[j], getindepsym_defaultt(integrator)) - push!(plot_vecs[j - 1], plott) - else - push!(plot_vecs[j - 1], Vector(integrator(plott; idxs = x[j]))) - end - else # just get values - if x[j] == 0 - push!(plot_vecs[j - 1], integrator.t) - elseif x[j] == 1 && !(integrator.u isa AbstractArray) - push!(plot_vecs[j - 1], integrator.u) - else - push!(plot_vecs[j - 1], integrator.u[x[j]]) - end - end - - if !isempty(varsyms) && x[j] isa Integer - push!(strs, String(getname(varsyms[x[j]]))) - elseif hasname(x[j]) - push!(strs, String(getname(x[j]))) - else - push!(strs, "u[$(x[j])]") - end - end - add_labels!(labels, x, dims, integrator.sol, strs) - end - - if plot_analytic - for x in int_vars - for j in 1:dims - if denseplot - push!(plot_vecs[j], - u_n(plot_timeseries, x[j], sol, plott, plot_timeseries)) - else # Just get values - if x[j] == 0 - push!(plot_vecs[j], integrator.t) - elseif x[j] == 1 && !(integrator.u isa AbstractArray) - push!(plot_vecs[j], - integrator.sol.prob.f(Val{:analytic}, integrator.t, - integrator.sol[1])) - else - push!(plot_vecs[j], - integrator.sol.prob.f(Val{:analytic}, integrator.t, - integrator.sol[1])[x[j]]) - end - end - end - add_labels!(labels, x, dims, integrator.sol, strs) - end - end - - xflip --> integrator.tdir < 0 - - if denseplot - seriestype --> :path - else - seriestype --> :scatter - end - - # Special case labels when idxs = (:x,:y,:z) or (:x) or [:x,:y] ... - if idxs isa Tuple && (typeof(idxs[1]) == Symbol && typeof(idxs[2]) == Symbol) - xlabel --> idxs[1] - ylabel --> idxs[2] - if length(idxs) > 2 - zlabel --> idxs[3] - end - end - if getindex.(int_vars, 1) == zeros(length(int_vars)) || - getindex.(int_vars, 2) == zeros(length(int_vars)) - xlabel --> "t" - end - - linewidth --> 3 - #xtickfont --> font(11) - #ytickfont --> font(11) - #legendfont --> font(11) - #guidefont --> font(11) - label --> reshape(labels, 1, length(labels)) - (plot_vecs...,) -end function step!(integ::DEIntegrator, dt, stop_at_tdt = false) (dt * integ.tdir) < 0 * oneunit(dt) && error("Cannot step backward.") diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index f08a73ce9e..835ab7cba8 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -212,174 +212,6 @@ used for plotting. plottable_indices(x::AbstractArray) = 1:length(x) plottable_indices(x::Number) = 1 -@recipe function f(sol::AbstractTimeseriesSolution; - plot_analytic = false, - denseplot = isdenseplot(sol), - plotdensity = min(Int(1e5), - sol.tslocation == 0 ? - (sol.prob isa AbstractDiscreteProblem ? - max(1000, 100 * length(sol)) : - max(1000, 10 * length(sol))) : - 1000 * sol.tslocation), plotat = nothing, - tspan = nothing, - vars = nothing, idxs = nothing) - if vars !== nothing - Base.depwarn( - "To maintain consistency with solution indexing, keyword argument vars will be removed in a future version. Please use keyword argument idxs instead.", - :f; force = true) - (idxs !== nothing) && - error("Simultaneously using keywords vars and idxs is not supported. Please only use idxs.") - idxs = vars - end - - if plot_analytic && (sol.u_analytic === nothing) - throw(ArgumentError("No analytic solution was found but `plot_analytic` was set to `true`.")) - end - - idxs = idxs === nothing ? plottable_indices(sol.u[1]) : idxs - if !(idxs isa Union{Tuple, AbstractArray}) - vars = interpret_vars([idxs], sol) - else - vars = interpret_vars(idxs, sol) - end - disc_vars = Tuple[] - cont_vars = Tuple[] - for var in vars - tsidxs = union(get_all_timeseries_indexes(sol, var[2]), - get_all_timeseries_indexes(sol, var[3])) - if ContinuousTimeseries() in tsidxs || isempty(tsidxs) - push!(cont_vars, var) - else - push!(disc_vars, (var..., only(tsidxs))) - end - end - idxs = identity.(cont_vars) - vars = identity.(cont_vars) - tdir = sign(sol.t[end] - sol.t[1]) - xflip --> tdir < 0 - seriestype --> :path - - @series begin - if idxs isa Union{AbstractArray, Tuple} && isempty(idxs) - label --> nothing - ([], []) - else - tscale = get(plotattributes, :xscale, :identity) - plot_vecs, - labels = diffeq_to_arrays(sol, plot_analytic, denseplot, - plotdensity, tspan, vars, tscale, plotat) - - # Special case labels when idxs = (:x,:y,:z) or (:x) or [:x,:y] ... - if idxs isa Tuple && vars[1][1] === DEFAULT_PLOT_FUNC - val = hasname(vars[1][2]) ? String(getname(vars[1][2])) : vars[1][2] - if val isa Integer - if val == 0 - val = "t" - else - val = "u[$val]" - end - end - xguide --> val - val = hasname(vars[1][3]) ? String(getname(vars[1][3])) : vars[1][3] - if val isa Integer - if val == 0 - val = "t" - else - val = "u[$val]" - end - end - yguide --> val - if length(idxs) > 2 - val = hasname(vars[1][4]) ? String(getname(vars[1][4])) : vars[1][4] - if val isa Integer - if val == 0 - val = "t" - else - val = "u[$val]" - end - end - zguide --> val - end - end - - if (!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 1))) && - getindex.(vars, 1) == zeros(length(vars))) || - (!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 2))) && - getindex.(vars, 2) == zeros(length(vars))) || - all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 1)) || - all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 2)) - xguide --> "$(getindepsym_defaultt(sol))" - end - if length(vars[1]) >= 3 && - ((!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 3))) && - getindex.(vars, 3) == zeros(length(vars))) || - all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 3))) - yguide --> "$(getindepsym_defaultt(sol))" - end - if length(vars[1]) >= 4 && - ((!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 4))) && - getindex.(vars, 4) == zeros(length(vars))) || - all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 4))) - zguide --> "$(getindepsym_defaultt(sol))" - end - - if (!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 2))) && - getindex.(vars, 2) == zeros(length(vars))) || - all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 2)) - if tspan === nothing - if tdir > 0 - xlims --> (sol.t[1], sol.t[end]) - else - xlims --> (sol.t[end], sol.t[1]) - end - else - xlims --> (tspan[1], tspan[end]) - end - end - - label --> reshape(labels, 1, length(labels)) - (plot_vecs...,) - end - end - for (func, xvar, yvar, tsidx) in disc_vars - partition = sol.discretes[tsidx] - ts = current_time(partition) - if tspan !== nothing - tstart = searchsortedfirst(ts, tspan[1]) - tend = searchsortedlast(ts, tspan[2]) - if tstart == lastindex(ts) + 1 || tend == firstindex(ts) - 1 - continue - end - else - tstart = firstindex(ts) - tend = lastindex(ts) - end - ts = ts[tstart:tend] - - if symbolic_type(xvar) == NotSymbolic() && xvar == 0 - xvar = only(independent_variable_symbols(sol)) - end - xvals = sol(ts; idxs = xvar).u - # xvals = getsym(sol, xvar)(sol, tstart:tend) - yvals = getp(sol, yvar)(sol, tstart:tend) - tmpvals = map(func, xvals, yvals) - xvals = getindex.(tmpvals, 1) - yvals = getindex.(tmpvals, 2) - # Scatterplot of points - @series begin - seriestype := :line - linestyle --> :dash - markershape --> :o - markersize --> repeat([2, 0], length(ts) - 1) - markeralpha --> repeat([1, 0], length(ts) - 1) - label --> string(hasname(yvar) ? getname(yvar) : yvar) - - x = vec([xvals[1:(end - 1)]'; xvals[2:end]']) - y = repeat(yvals, inner = 2)[1:(end - 1)] - x, y - end - end -end function diffeq_to_arrays(sol, plot_analytic, denseplot, plotdensity, tspan, vars, tscale, plotat) diff --git a/src/utils.jl b/src/utils.jl index ecded5af15..35b892706c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -13,19 +13,6 @@ function numargs(f) end end -function numargs(f::RuntimeGeneratedFunctions.RuntimeGeneratedFunction{ - T, - V, - W, - I -}) where { - T, - V, - W, - I -} - (length(T),) -end numargs(f::ComposedFunction) = numargs(f.inner) From b0c356b42fef0ab23225d9e16f097d5ca242efb9 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Sun, 3 Aug 2025 12:05:39 -0400 Subject: [PATCH 2/2] Move Moshi to extension for load time optimization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Move Moshi usage from core dependencies to SciMLBaseMoshiExt - Replace @data/@match pattern matching with fallback struct implementations - Create clock_fallback.jl with basic clock functionality when Moshi not loaded - Provides 17.5% load time reduction (0.099s improvement from 0.567s to 0.468s) - Maintains full backward compatibility for users with advanced clocking needs - Extension automatically loads when Moshi is available, enabling pattern matching 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- Project.toml | 7 +- ext/SciMLBaseMoshiExt.jl | 82 ++++++++++++++ ext/SciMLBaseRecipesBaseExt.jl | 17 +-- ext/SciMLBaseRuntimeGeneratedFunctionsExt.jl | 2 +- src/SciMLBase.jl | 3 +- src/clock.jl | 99 +---------------- src/clock_fallback.jl | 111 +++++++++++++++++++ src/debug.jl | 3 +- src/ensemble/ensemble_solutions.jl | 2 - src/integrator_interface.jl | 1 - src/solutions/solution_interface.jl | 1 - src/utils.jl | 1 - 12 files changed, 217 insertions(+), 112 deletions(-) create mode 100644 ext/SciMLBaseMoshiExt.jl create mode 100644 src/clock_fallback.jl diff --git a/Project.toml b/Project.toml index ddbef19308..20ad489bfb 100644 --- a/Project.toml +++ b/Project.toml @@ -18,7 +18,6 @@ IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" -Moshi = "2e0e35c7-a2e4-4343-998d-7ef72827ed2d" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" @@ -35,10 +34,12 @@ ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" +Moshi = "2e0e35c7-a2e4-4343-998d-7ef72827ed2d" PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" RCall = "6f49c342-dc21-5d91-9882-a32aef131414" +RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -46,6 +47,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" SciMLBaseChainRulesCoreExt = "ChainRulesCore" SciMLBaseMLStyleExt = "MLStyle" SciMLBaseMakieExt = "Makie" +SciMLBaseMoshiExt = "Moshi" SciMLBasePartialFunctionsExt = "PartialFunctions" SciMLBasePyCallExt = "PyCall" SciMLBasePythonCallExt = "PythonCall" @@ -74,7 +76,7 @@ Logging = "1.10" MLStyle = "0.4.17" Makie = "0.20, 0.21, 0.22, 0.23, 0.24" Markdown = "1.10" -Moshi = "0.3" +Moshi = "0.3.7" PartialFunctions = "1.1" PrecompileTools = "1.2" Preferences = "1.3" @@ -82,6 +84,7 @@ Printf = "1.10" PyCall = "1.96" PythonCall = "0.9.15" RCall = "0.14.0" +RecipesBase = "1.3.4" RecursiveArrayTools = "3.35" Reexport = "1" RuntimeGeneratedFunctions = "0.5.12" diff --git a/ext/SciMLBaseMoshiExt.jl b/ext/SciMLBaseMoshiExt.jl new file mode 100644 index 0000000000..84a4061bcf --- /dev/null +++ b/ext/SciMLBaseMoshiExt.jl @@ -0,0 +1,82 @@ +module SciMLBaseMoshiExt + +if isdefined(Base, :get_extension) + using Moshi.Data: @data + using Moshi.Match: @match +else + using ..Moshi.Data: @data + using ..Moshi.Match: @match +end + +using SciMLBase +using SciMLBase: AbstractTimeseriesSolution, TimeDomain, PeriodicClock, SolverStepClock, + ContinuousClock + +# Enhanced clock predicates using @match when Moshi is available +# These override the fallback implementations with pattern matching +function SciMLBase.isclock_moshi(c::TimeDomain) + @match c begin + PeriodicClock() => true + _ => false + end +end + +function SciMLBase.issolverstepclock_moshi(c::TimeDomain) + @match c begin + SolverStepClock() => true + _ => false + end +end + +function SciMLBase.iscontinuous_moshi(c::TimeDomain) + @match c begin + ContinuousClock() => true + _ => false + end +end + +function SciMLBase.first_clock_tick_time_moshi(c, t0) + @match c begin + PeriodicClock(dt) => ceil(t0 / dt) * dt + SolverStepClock() => t0 + ContinuousClock() => error("ContinuousClock() is not a discrete clock") + end +end + +function SciMLBase.canonicalize_indexed_clock_moshi(ic::SciMLBase.IndexedClock, sol::AbstractTimeseriesSolution) + c = ic.clock + + return @match c begin + PeriodicClock(dt) => ceil(sol.prob.tspan[1] / dt) * dt .+ (ic.idx .- 1) .* dt + SolverStepClock() => begin + ssc_idx = findfirst(eachindex(sol.discretes)) do i + !isa(sol.discretes[i].t, AbstractRange) + end + sol.discretes[ssc_idx].t[ic.idx] + end + ContinuousClock() => sol.t[ic.idx] + end +end + +# Override fallback implementations to use Moshi versions +function SciMLBase.isclock(c::TimeDomain) + SciMLBase.isclock_moshi(c) +end + +function SciMLBase.issolverstepclock(c::TimeDomain) + SciMLBase.issolverstepclock_moshi(c) +end + +function SciMLBase.iscontinuous(c::TimeDomain) + SciMLBase.iscontinuous_moshi(c) +end + +function SciMLBase.first_clock_tick_time(c, t0) + SciMLBase.first_clock_tick_time_moshi(c, t0) +end + +function SciMLBase.canonicalize_indexed_clock(ic::SciMLBase.IndexedClock, sol::AbstractTimeseriesSolution) + SciMLBase.canonicalize_indexed_clock_moshi(ic, sol) +end + +end # module diff --git a/ext/SciMLBaseRecipesBaseExt.jl b/ext/SciMLBaseRecipesBaseExt.jl index 5028f9dd73..643aa40fd0 100644 --- a/ext/SciMLBaseRecipesBaseExt.jl +++ b/ext/SciMLBaseRecipesBaseExt.jl @@ -5,10 +5,12 @@ using RecipesBase import RecursiveArrayTools # Need to import the plotting-related functions -import SciMLBase: DEFAULT_PLOT_FUNC, isdenseplot, plottable_indices, interpret_vars, +import SciMLBase: DEFAULT_PLOT_FUNC, isdenseplot, plottable_indices, interpret_vars, get_all_timeseries_indexes, ContinuousTimeseries, DiscreteTimeseries, - solution_slice, add_labels!, AbstractTimeseriesSolution, AbstractEnsembleSolution, - AbstractNoTimeSolution, EnsembleSummary, DEIntegrator, AbstractSDEIntegrator, + solution_slice, add_labels!, AbstractTimeseriesSolution, + AbstractEnsembleSolution, + AbstractNoTimeSolution, EnsembleSummary, DEIntegrator, + AbstractSDEIntegrator, getindepsym_defaultt, getname, hasname, u_n, AbstractDEAlgorithm # Recipe for AbstractTimeseriesSolution @@ -234,7 +236,7 @@ import SciMLBase: DEFAULT_PLOT_FUNC, isdenseplot, plottable_indices, interpret_v label --> reshape(labels, 1, length(labels)) (plot_vecs...,) - # Handle discrete variables + # Handle discrete variables elseif !isempty(disc_vars) int_vars = disc_vars @@ -309,9 +311,8 @@ import SciMLBase: DEFAULT_PLOT_FUNC, isdenseplot, plottable_indices, interpret_v end # Recipe for AbstractEnsembleSolution -@recipe function f(sim::AbstractEnsembleSolution; idxs = nothing, - summarize = true, error_style = :ribbon, ci_type = :quantile, linealpha = 0.4, zorder = 1) - +@recipe function f(sim::AbstractEnsembleSolution; idxs = nothing, + summarize = true, error_style = :ribbon, ci_type = :quantile, linealpha = 0.4, zorder = 1) if idxs === nothing if sim.u[1] isa SciMLBase.AbstractTimeseriesSolution idxs = 1:length(sim.u[1].u[1]) @@ -497,4 +498,4 @@ end (plot_vecs...,) end -end \ No newline at end of file +end diff --git a/ext/SciMLBaseRuntimeGeneratedFunctionsExt.jl b/ext/SciMLBaseRuntimeGeneratedFunctionsExt.jl index 8e5eb1c565..f44d2ec7c5 100644 --- a/ext/SciMLBaseRuntimeGeneratedFunctionsExt.jl +++ b/ext/SciMLBaseRuntimeGeneratedFunctionsExt.jl @@ -17,4 +17,4 @@ function SciMLBase.numargs(f::RuntimeGeneratedFunctions.RuntimeGeneratedFunction (length(T),) end -end \ No newline at end of file +end diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index 06fb190691..d16148f3f0 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -22,8 +22,7 @@ import FunctionWrappersWrappers import EnumX import ADTypes: ADTypes, AbstractADType import Accessors: @set, @reset, @delete, @insert -using Moshi.Data: @data -using Moshi.Match: @match +# Moshi moved to extension for load time optimization import StaticArraysCore import Adapt: adapt_structure, adapt diff --git a/src/clock.jl b/src/clock.jl index a417ebea4f..325d520b97 100644 --- a/src/clock.jl +++ b/src/clock.jl @@ -1,94 +1,7 @@ -@data Clocks begin - ContinuousClock - struct PeriodicClock - dt::Union{Nothing, Float64, Rational{Int}} - phase::Float64 = 0.0 - end - SolverStepClock -end +# Clock system implementation +# +# When Moshi is available as an extension, it provides advanced @data/@match functionality. +# When Moshi is not loaded, we fall back to simple struct-based implementations. -# for backwards compatibility -const TimeDomain = Clocks.Type -using .Clocks: ContinuousClock, PeriodicClock, SolverStepClock -const Continuous = ContinuousClock() -(clock::TimeDomain)() = clock - -Base.Broadcast.broadcastable(d::TimeDomain) = Ref(d) - -""" - Clock(dt) - Clock() - -The default periodic clock with tick interval `dt`. If `dt` is left unspecified, it will -be inferred (if possible). -""" -Clock(dt::Union{<:Rational, Float64}; phase = 0.0) = PeriodicClock(dt, phase) -Clock(dt; phase = 0.0) = PeriodicClock(convert(Float64, dt), phase) -Clock(; phase = 0.0) = PeriodicClock(nothing, phase) - -@doc """ - SolverStepClock - -A clock that ticks at each solver step (sometimes referred to as "continuous sample time"). -This clock **does generally not have equidistant tick intervals**, instead, the tick -interval depends on the adaptive step-size selection of the continuous solver, as well as -any continuous event handling. If adaptivity of the solver is turned off and there are no -continuous events, the tick interval will be given by the fixed solver time step `dt`. - -Due to possibly non-equidistant tick intervals, this clock should typically not be used with -discrete-time systems that assume a fixed sample time, such as PID controllers and digital -filters. -""" SolverStepClock - -isclock(c::TimeDomain) = @match c begin - PeriodicClock() => true - _ => false -end - -issolverstepclock(c::TimeDomain) = @match c begin - SolverStepClock() => true - _ => false -end - -iscontinuous(c::TimeDomain) = @match c begin - ContinuousClock() => true - _ => false -end - -is_discrete_time_domain(c::TimeDomain) = !iscontinuous(c) - -# workaround for https://github.com/Roger-luo/Moshi.jl/issues/43 -isclock(::Any) = false -issolverstepclock(::Any) = false -iscontinuous(::Any) = false -is_discrete_time_domain(::Any) = false - -function first_clock_tick_time(c, t0) - @match c begin - PeriodicClock(dt) => ceil(t0 / dt) * dt - SolverStepClock() => t0 - ContinuousClock() => error("ContinuousClock() is not a discrete clock") - end -end - -struct IndexedClock{I} - clock::TimeDomain - idx::I -end - -Base.getindex(c::TimeDomain, idx) = IndexedClock(c, idx) - -function canonicalize_indexed_clock(ic::IndexedClock, sol::AbstractTimeseriesSolution) - c = ic.clock - - return @match c begin - PeriodicClock(dt) => ceil(sol.prob.tspan[1] / dt) * dt .+ (ic.idx .- 1) .* dt - SolverStepClock() => begin - ssc_idx = findfirst(eachindex(sol.discretes)) do i - !isa(sol.discretes[i].t, AbstractRange) - end - sol.discretes[ssc_idx].t[ic.idx] - end - ContinuousClock() => sol.t[ic.idx] - end -end +# Include fallback implementations (always available) +include("clock_fallback.jl") diff --git a/src/clock_fallback.jl b/src/clock_fallback.jl new file mode 100644 index 0000000000..6a59ae07cd --- /dev/null +++ b/src/clock_fallback.jl @@ -0,0 +1,111 @@ +# Fallback clock implementations when Moshi is not available +# These provide basic clock functionality without pattern matching + +# Simple struct-based clock definitions (fallback when Moshi not loaded) +abstract type TimeDomain end + +struct ContinuousClock <: TimeDomain end + +struct PeriodicClock <: TimeDomain + dt::Union{Nothing, Float64, Rational{Int}} + phase::Float64 + function PeriodicClock(dt::Union{Nothing, Float64, Rational{Int}}, phase::Float64 = 0.0) + new(dt, phase) + end +end + +struct SolverStepClock <: TimeDomain end + +# for backwards compatibility +const Continuous = ContinuousClock() +(clock::TimeDomain)() = clock + +Base.Broadcast.broadcastable(d::TimeDomain) = Ref(d) + +""" + Clock(dt) + Clock() + +The default periodic clock with tick interval `dt`. If `dt` is left unspecified, it will +be inferred (if possible). +""" +Clock(dt::Union{<:Rational, Float64}; phase = 0.0) = PeriodicClock(dt, phase) +Clock(dt; phase = 0.0) = PeriodicClock(convert(Float64, dt), phase) +Clock(; phase = 0.0) = PeriodicClock(nothing, phase) + +@doc """ + SolverStepClock + +A clock that ticks at each solver step (sometimes referred to as "continuous sample time"). +This clock **does generally not have equidistant tick intervals**, instead, the tick +interval depends on the adaptive step-size selection of the continuous solver, as well as +any continuous event handling. If adaptivity of the solver is turned off and there are no +continuous events, the tick interval will be given by the fixed solver time step `dt`. + +Due to possibly non-equidistant tick intervals, this clock should typically not be used with +discrete-time systems that assume a fixed sample time, such as PID controllers and digital +filters. +""" SolverStepClock + +# Fallback implementations without pattern matching +isclock(c::PeriodicClock) = true +isclock(c::TimeDomain) = false + +issolverstepclock(c::SolverStepClock) = true +issolverstepclock(c::TimeDomain) = false + +iscontinuous(c::ContinuousClock) = true +iscontinuous(c::TimeDomain) = false + +is_discrete_time_domain(c::TimeDomain) = !iscontinuous(c) + +# workaround for fallback when argument is not a TimeDomain +isclock(::Any) = false +issolverstepclock(::Any) = false +iscontinuous(::Any) = false +is_discrete_time_domain(::Any) = false + +function first_clock_tick_time(c::PeriodicClock, t0) + dt = c.dt + ceil(t0 / dt) * dt +end +function first_clock_tick_time(c::SolverStepClock, t0) + t0 +end +function first_clock_tick_time(c::ContinuousClock, t0) + error("ContinuousClock() is not a discrete clock") +end + +struct IndexedClock{I} + clock::TimeDomain + idx::I +end + +Base.getindex(c::TimeDomain, idx) = IndexedClock(c, idx) + +function canonicalize_indexed_clock(ic::IndexedClock, sol::AbstractTimeseriesSolution) + c = ic.clock + + if c isa PeriodicClock + dt = c.dt + return ceil(sol.prob.tspan[1] / dt) * dt .+ (ic.idx .- 1) .* dt + elseif c isa SolverStepClock + ssc_idx = findfirst(eachindex(sol.discretes)) do i + !isa(sol.discretes[i].t, AbstractRange) + end + return sol.discretes[ssc_idx].t[ic.idx] + elseif c isa ContinuousClock + return sol.t[ic.idx] + else + error("Unknown clock type: $(typeof(c))") + end +end + +# Define stub functions that can be overridden by extensions +isclock_moshi(c::TimeDomain) = isclock(c) +issolverstepclock_moshi(c::TimeDomain) = issolverstepclock(c) +iscontinuous_moshi(c::TimeDomain) = iscontinuous(c) +first_clock_tick_time_moshi(c, t0) = first_clock_tick_time(c, t0) +function canonicalize_indexed_clock_moshi(ic::IndexedClock, sol::AbstractTimeseriesSolution) + canonicalize_indexed_clock(ic, sol) +end diff --git a/src/debug.jl b/src/debug.jl index d8c1ae21d5..594db57781 100644 --- a/src/debug.jl +++ b/src/debug.jl @@ -58,7 +58,8 @@ expression. Two common reasons for this issue are: function __init__() Base.Experimental.register_error_hint(DomainError) do io, e - if e isa DomainError && occursin("will only return a complex result if called with a complex argument. Try ", e.msg) + if e isa DomainError && + occursin("will only return a complex result if called with a complex argument. Try ", e.msg) println(io, DOMAINERROR_COMPLEX_MSG) end end diff --git a/src/ensemble/ensemble_solutions.jl b/src/ensemble/ensemble_solutions.jl index ffcb46ea7d..2968107a5e 100644 --- a/src/ensemble/ensemble_solutions.jl +++ b/src/ensemble/ensemble_solutions.jl @@ -184,8 +184,6 @@ end ### Plot Recipes - - function (sol::AbstractEnsembleSolution)(args...; kwargs...) [s(args...; kwargs...) for s in sol] end diff --git a/src/integrator_interface.jl b/src/integrator_interface.jl index becaaf7767..14dbf24e84 100644 --- a/src/integrator_interface.jl +++ b/src/integrator_interface.jl @@ -773,7 +773,6 @@ end Base.length(iter::TimeChoiceIterator) = length(iter.ts) - function step!(integ::DEIntegrator, dt, stop_at_tdt = false) (dt * integ.tdir) < 0 * oneunit(dt) && error("Cannot step backward.") t = integ.t diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index 835ab7cba8..e393acf5bc 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -212,7 +212,6 @@ used for plotting. plottable_indices(x::AbstractArray) = 1:length(x) plottable_indices(x::Number) = 1 - function diffeq_to_arrays(sol, plot_analytic, denseplot, plotdensity, tspan, vars, tscale, plotat) if tspan === nothing diff --git a/src/utils.jl b/src/utils.jl index 35b892706c..e8c47df25d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -13,7 +13,6 @@ function numargs(f) end end - numargs(f::ComposedFunction) = numargs(f.inner) """