diff --git a/Project.toml b/Project.toml index 92b029a3a..20ad489bf 100644 --- a/Project.toml +++ b/Project.toml @@ -18,14 +18,11 @@ IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" -Moshi = "2e0e35c7-a2e4-4343-998d-7ef72827ed2d" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" -RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" -RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" @@ -37,20 +34,26 @@ ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" +Moshi = "2e0e35c7-a2e4-4343-998d-7ef72827ed2d" PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" RCall = "6f49c342-dc21-5d91-9882-a32aef131414" +RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" +RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] SciMLBaseChainRulesCoreExt = "ChainRulesCore" SciMLBaseMLStyleExt = "MLStyle" SciMLBaseMakieExt = "Makie" +SciMLBaseMoshiExt = "Moshi" SciMLBasePartialFunctionsExt = "PartialFunctions" SciMLBasePyCallExt = "PyCall" SciMLBasePythonCallExt = "PythonCall" SciMLBaseRCallExt = "RCall" +SciMLBaseRecipesBaseExt = "RecipesBase" +SciMLBaseRuntimeGeneratedFunctionsExt = "RuntimeGeneratedFunctions" SciMLBaseZygoteExt = ["Zygote", "ChainRulesCore"] [compat] @@ -73,7 +76,7 @@ Logging = "1.10" MLStyle = "0.4.17" Makie = "0.20, 0.21, 0.22, 0.23, 0.24" Markdown = "1.10" -Moshi = "0.3" +Moshi = "0.3.7" PartialFunctions = "1.1" PrecompileTools = "1.2" Preferences = "1.3" diff --git a/ext/SciMLBaseMoshiExt.jl b/ext/SciMLBaseMoshiExt.jl new file mode 100644 index 000000000..84a4061bc --- /dev/null +++ b/ext/SciMLBaseMoshiExt.jl @@ -0,0 +1,82 @@ +module SciMLBaseMoshiExt + +if isdefined(Base, :get_extension) + using Moshi.Data: @data + using Moshi.Match: @match +else + using ..Moshi.Data: @data + using ..Moshi.Match: @match +end + +using SciMLBase +using SciMLBase: AbstractTimeseriesSolution, TimeDomain, PeriodicClock, SolverStepClock, + ContinuousClock + +# Enhanced clock predicates using @match when Moshi is available +# These override the fallback implementations with pattern matching +function SciMLBase.isclock_moshi(c::TimeDomain) + @match c begin + PeriodicClock() => true + _ => false + end +end + +function SciMLBase.issolverstepclock_moshi(c::TimeDomain) + @match c begin + SolverStepClock() => true + _ => false + end +end + +function SciMLBase.iscontinuous_moshi(c::TimeDomain) + @match c begin + ContinuousClock() => true + _ => false + end +end + +function SciMLBase.first_clock_tick_time_moshi(c, t0) + @match c begin + PeriodicClock(dt) => ceil(t0 / dt) * dt + SolverStepClock() => t0 + ContinuousClock() => error("ContinuousClock() is not a discrete clock") + end +end + +function SciMLBase.canonicalize_indexed_clock_moshi(ic::SciMLBase.IndexedClock, sol::AbstractTimeseriesSolution) + c = ic.clock + + return @match c begin + PeriodicClock(dt) => ceil(sol.prob.tspan[1] / dt) * dt .+ (ic.idx .- 1) .* dt + SolverStepClock() => begin + ssc_idx = findfirst(eachindex(sol.discretes)) do i + !isa(sol.discretes[i].t, AbstractRange) + end + sol.discretes[ssc_idx].t[ic.idx] + end + ContinuousClock() => sol.t[ic.idx] + end +end + +# Override fallback implementations to use Moshi versions +function SciMLBase.isclock(c::TimeDomain) + SciMLBase.isclock_moshi(c) +end + +function SciMLBase.issolverstepclock(c::TimeDomain) + SciMLBase.issolverstepclock_moshi(c) +end + +function SciMLBase.iscontinuous(c::TimeDomain) + SciMLBase.iscontinuous_moshi(c) +end + +function SciMLBase.first_clock_tick_time(c, t0) + SciMLBase.first_clock_tick_time_moshi(c, t0) +end + +function SciMLBase.canonicalize_indexed_clock(ic::SciMLBase.IndexedClock, sol::AbstractTimeseriesSolution) + SciMLBase.canonicalize_indexed_clock_moshi(ic, sol) +end + +end # module diff --git a/ext/SciMLBaseRecipesBaseExt.jl b/ext/SciMLBaseRecipesBaseExt.jl new file mode 100644 index 000000000..643aa40fd --- /dev/null +++ b/ext/SciMLBaseRecipesBaseExt.jl @@ -0,0 +1,501 @@ +module SciMLBaseRecipesBaseExt + +using SciMLBase +using RecipesBase +import RecursiveArrayTools + +# Need to import the plotting-related functions +import SciMLBase: DEFAULT_PLOT_FUNC, isdenseplot, plottable_indices, interpret_vars, + get_all_timeseries_indexes, ContinuousTimeseries, DiscreteTimeseries, + solution_slice, add_labels!, AbstractTimeseriesSolution, + AbstractEnsembleSolution, + AbstractNoTimeSolution, EnsembleSummary, DEIntegrator, + AbstractSDEIntegrator, + getindepsym_defaultt, getname, hasname, u_n, AbstractDEAlgorithm + +# Recipe for AbstractTimeseriesSolution +@recipe function f(sol::AbstractTimeseriesSolution; + plot_analytic = false, + denseplot = isdenseplot(sol), + plotdensity = min(Int(1e5), + sol.tslocation == 0 ? + (sol.prob isa SciMLBase.AbstractDiscreteProblem ? + max(1000, 100 * length(sol)) : + max(1000, 10 * length(sol))) : + 1000 * sol.tslocation), plotat = nothing, + tspan = nothing, + vars = nothing, idxs = nothing) + if vars !== nothing + Base.depwarn( + "To maintain consistency with solution indexing, keyword argument vars will be removed in a future version. Please use keyword argument idxs instead.", + :f; force = true) + (idxs !== nothing) && + error("Simultaneously using keywords vars and idxs is not supported. Please only use idxs.") + idxs = vars + end + + if plot_analytic && (sol.u_analytic === nothing) + throw(ArgumentError("No analytic solution was found but `plot_analytic` was set to `true`.")) + end + + idxs = idxs === nothing ? plottable_indices(sol.u[1]) : idxs + if !(idxs isa Union{Tuple, AbstractArray}) + vars = interpret_vars([idxs], sol) + else + vars = interpret_vars(idxs, sol) + end + disc_vars = Tuple[] + cont_vars = Tuple[] + for var in vars + tsidxs = union(get_all_timeseries_indexes(sol, var[2]), + get_all_timeseries_indexes(sol, var[3])) + if ContinuousTimeseries() in tsidxs || isempty(tsidxs) + push!(cont_vars, var) + else + push!(disc_vars, var) + end + end + + plot_vecs = [] + labels = [] + + # Handle continuous variables + if !isempty(cont_vars) + int_vars = cont_vars + + if tspan === nothing + if plotat === nothing + if denseplot + # Generate the points from the plot from dense function + start_idx = sol.tslocation == 0 ? 1 : sol.tslocation + end_idx = length(sol.t) + plott = collect(range(sol.t[start_idx], sol.t[end_idx]; length = plotdensity)) + plot_timeseries = sol(plott) + if plot_analytic + plot_analytic_timeseries = [sol.prob.f.analytic(sol.prob.u0, + sol.prob.p, t) + for t in plott] + end + else + plot_timeseries = sol.u + plott = sol.t + if plot_analytic + plot_analytic_timeseries = sol.u_analytic + end + end + else + plot_timeseries = sol(plotat) + plott = plotat + if plot_analytic + plot_analytic_timeseries = [sol.prob.f.analytic(sol.prob.u0, + sol.prob.p, t) for t in plott] + end + end + else + _tspan = tspan isa Number ? (sol.t[1], tspan) : tspan + start_idx = findfirst(x -> x >= _tspan[1], sol.t) + end_idx = findlast(x -> x <= _tspan[2], sol.t) + if denseplot + plott = collect(range(_tspan...; length = plotdensity)) + plot_timeseries = sol(plott) + if plot_analytic + plot_analytic_timeseries = [sol.prob.f.analytic(sol.prob.u0, + sol.prob.p, t) for t in plott] + end + else + if start_idx === nothing + start_idx = 1 + end + if end_idx === nothing + end_idx = length(sol.t) + end + plott = @view sol.t[start_idx:end_idx] + plot_timeseries = @view sol.u[start_idx:end_idx] + if plot_analytic + plot_analytic_timeseries = @view sol.u_analytic[start_idx:end_idx] + end + end + end + + dims = length(int_vars[1]) + for var in int_vars + @assert length(var) == dims + end + # Should check that all have the same dims! + + for i in 2:dims + push!(plot_vecs, []) + end + + labels = String[]# Array{String, 2}(1, length(int_vars)*(1+plot_analytic)) + strs = String[] + varsyms = SciMLBase.variable_symbols(sol) + + for x in int_vars + for j in 2:dims + if denseplot + if (x[j] isa Integer && x[j] == 0) || + isequal(x[j], SciMLBase.getindepsym_defaultt(sol)) + push!(plot_vecs[j - 1], plott) + else + # For the dense plotting case, we use getindex on the timeseries + if plot_timeseries isa AbstractArray + if x[j] isa Integer + # Simple integer indexing + push!(plot_vecs[j - 1], [u[x[j]] for u in plot_timeseries]) + else + # Symbol indexing + push!(plot_vecs[j - 1], Vector(sol(plott; idxs = x[j]))) + end + else + # Single value case + push!(plot_vecs[j - 1], Vector(sol(plott; idxs = x[j]))) + end + end + else # just get values + if x[j] == 0 + push!(plot_vecs[j - 1], plott) + elseif x[j] == 1 && !(eltype(plot_timeseries) <: AbstractArray) + push!(plot_vecs[j - 1], plot_timeseries) + else + if x[j] isa Integer + push!(plot_vecs[j - 1], [u[x[j]] for u in plot_timeseries]) + else + # Symbol indexing + push!(plot_vecs[j - 1], [sol(t, idxs = x[j]) for t in plott]) + end + end + end + + if !isempty(varsyms) && x[j] isa Integer + push!(strs, String(getname(varsyms[x[j]]))) + elseif hasname(x[j]) + push!(strs, String(getname(x[j]))) + else + push!(strs, "u[$(x[j])]") + end + end + add_labels!(labels, x, dims, sol, strs) + end + + if plot_analytic + for x in int_vars + for j in 2:dims + if denseplot + if (x[j] isa Integer && x[j] == 0) || + isequal(x[j], SciMLBase.getindepsym_defaultt(sol)) + push!(plot_vecs[j - 1], plott) + else + push!(plot_vecs[j - 1], + u_n(plot_analytic_timeseries, x[j], sol, plott, + plot_analytic_timeseries)) + end + else # Just get values + if x[j] == 0 + push!(plot_vecs[j - 1], plott) + elseif x[j] == 1 && + !(eltype(plot_analytic_timeseries) <: AbstractArray) + push!(plot_vecs[j - 1], plot_analytic_timeseries) + else + push!(plot_vecs[j - 1], + u_n(plot_analytic_timeseries, x[j], sol, plott, + plot_analytic_timeseries)) + end + end + end + add_labels!(labels, x, dims, sol, strs) + end + end + + xflip --> sol.tdir < 0 + + if denseplot + seriestype --> :path + else + seriestype --> :scatter + end + + # Special case labels when idxs = (:x,:y,:z) or (:x) or [:x,:y] .. + if idxs isa Tuple && (typeof(idxs[1]) == Symbol && typeof(idxs[2]) == Symbol) + xlabel --> idxs[1] + ylabel --> idxs[2] + if length(idxs) > 2 + zlabel --> idxs[3] + end + end + if getindex.(int_vars, 1) == zeros(length(int_vars)) || + getindex.(int_vars, 2) == zeros(length(int_vars)) + xlabel --> "t" + end + + linewidth --> 3 + #xtickfont --> font(11) + #ytickfont --> font(11) + #legendfont --> font(11) + #guidefont --> font(11) + label --> reshape(labels, 1, length(labels)) + (plot_vecs...,) + + # Handle discrete variables + elseif !isempty(disc_vars) + int_vars = disc_vars + + if sol.tslocation != 0 + start_idx = sol.tslocation + else + start_idx = 1 + end + + if tspan === nothing + end_idx = length(sol.t) + else + _tspan = tspan isa Number ? (sol.t[1], tspan) : tspan + end_idx = findlast(x -> x <= _tspan[2], sol.t) + if end_idx === nothing + end_idx = length(sol.t) + end + end + + plott = sol.t[start_idx:end_idx] + plot_timeseries = sol.u[start_idx:end_idx] + + dims = length(int_vars[1]) + for var in int_vars + @assert length(var) == dims + end + + for i in 2:dims + push!(plot_vecs, []) + end + + labels = String[] + strs = String[] + varsyms = SciMLBase.variable_symbols(sol) + + for x in int_vars + for j in 2:dims + if x[j] == 0 + push!(plot_vecs[j - 1], plott) + elseif x[j] == 1 && !(eltype(plot_timeseries) <: AbstractArray) + push!(plot_vecs[j - 1], plot_timeseries) + else + if x[j] isa Integer + push!(plot_vecs[j - 1], [u[x[j]] for u in plot_timeseries]) + else + # Symbol indexing for discrete case + push!(plot_vecs[j - 1], [sol[ti, x[j]] for ti in 1:length(plott)]) + end + end + + if !isempty(varsyms) && x[j] isa Integer + push!(strs, String(getname(varsyms[x[j]]))) + elseif hasname(x[j]) + push!(strs, String(getname(x[j]))) + else + push!(strs, "u[$(x[j])]") + end + end + add_labels!(labels, x, dims, sol, strs) + end + + seriestype --> :steppost + if getindex.(int_vars, 1) == zeros(length(int_vars)) || + getindex.(int_vars, 2) == zeros(length(int_vars)) + xlabel --> "t" + end + + linewidth --> 3 + label --> reshape(labels, 1, length(labels)) + (plot_vecs...,) + end +end + +# Recipe for AbstractEnsembleSolution +@recipe function f(sim::AbstractEnsembleSolution; idxs = nothing, + summarize = true, error_style = :ribbon, ci_type = :quantile, linealpha = 0.4, zorder = 1) + if idxs === nothing + if sim.u[1] isa SciMLBase.AbstractTimeseriesSolution + idxs = 1:length(sim.u[1].u[1]) + else + idxs = 1:length(sim.u[1]) + end + end + + if !(idxs isa Union{Tuple, AbstractArray}) + idxs = [idxs] + end + + if summarize + summary = EnsembleSummary(sim; quantiles = [0.05, 0.95]) + if error_style == :ribbon + ribbon --> (summary.qlow[:, idxs], summary.qhigh[:, idxs]) + elseif error_style == :bars + yerror --> (summary.qlow[:, idxs], summary.qhigh[:, idxs]) + end + summary.t, summary.med[:, idxs] + else + alpha --> linealpha + # Plot all trajectories + for i in eachindex(sim.u) + @series begin + if sim.u[i] isa SciMLBase.AbstractTimeseriesSolution + idxs --> idxs + sim.u[i] + else + # For non-timeseries solutions + sim.u[i][idxs] + end + end + end + end +end + +# Recipe for EnsembleSummary +@recipe function f(sim::EnsembleSummary; idxs = nothing, error_style = :ribbon) + if idxs === nothing + idxs = 1:size(sim.med, 2) + end + + if !(idxs isa Union{Tuple, AbstractArray}) + idxs = [idxs] + end + + if error_style == :ribbon + ribbon --> (sim.qlow[:, idxs], sim.qhigh[:, idxs]) + elseif error_style == :bars + yerror --> (sim.qlow[:, idxs], sim.qhigh[:, idxs]) + end + sim.t, sim.med[:, idxs] +end + +# Recipe for DEIntegrator +@recipe function f(integrator::DEIntegrator; + denseplot = (integrator.opts.calck || + integrator isa AbstractSDEIntegrator) && + integrator.iter > 0, + plotdensity = 10, + plot_analytic = false, vars = nothing, idxs = nothing) + if vars !== nothing + Base.depwarn( + "To maintain consistency with solution indexing, keyword argument vars will be removed in a future version. Please use keyword argument idxs instead.", + :f; force = true) + (idxs !== nothing) && + error("Simultaneously using keywords vars and idxs is not supported. Please only use idxs.") + idxs = vars + end + + int_vars = interpret_vars(idxs, integrator.sol) + + if denseplot + # Generate the points from the plot from dense function + plott = collect(range(integrator.tprev, integrator.t; length = plotdensity)) + if plot_analytic + plot_analytic_timeseries = [integrator.sol.prob.f.analytic( + integrator.sol.prob.u0, + integrator.sol.prob.p, + t) for t in plott] + end + else + plott = nothing + end + + dims = length(int_vars[1]) + for var in int_vars + @assert length(var) == dims + end + # Should check that all have the same dims! + + plot_vecs = [] + for i in 2:dims + push!(plot_vecs, []) + end + + labels = String[]# Array{String, 2}(1, length(int_vars)*(1+plot_analytic)) + strs = String[] + varsyms = SciMLBase.variable_symbols(integrator) + + for x in int_vars + for j in 2:dims + if denseplot + if (x[j] isa Integer && x[j] == 0) || + isequal(x[j], getindepsym_defaultt(integrator)) + push!(plot_vecs[j - 1], plott) + else + push!(plot_vecs[j - 1], Vector(integrator(plott; idxs = x[j]))) + end + else # just get values + if x[j] == 0 + push!(plot_vecs[j - 1], integrator.t) + elseif x[j] == 1 && !(integrator.u isa AbstractArray) + push!(plot_vecs[j - 1], integrator.u) + else + push!(plot_vecs[j - 1], integrator.u[x[j]]) + end + end + + if !isempty(varsyms) && x[j] isa Integer + push!(strs, String(getname(varsyms[x[j]]))) + elseif hasname(x[j]) + push!(strs, String(getname(x[j]))) + else + push!(strs, "u[$(x[j])]") + end + end + add_labels!(labels, x, dims, integrator.sol, strs) + end + + if plot_analytic + for x in int_vars + for j in 1:dims + if denseplot + push!(plot_vecs[j], + u_n(plot_timeseries, x[j], sol, plott, plot_timeseries)) + else # Just get values + if x[j] == 0 + push!(plot_vecs[j], integrator.t) + elseif x[j] == 1 && !(integrator.u isa AbstractArray) + push!(plot_vecs[j], + integrator.sol.prob.f(Val{:analytic}, integrator.t, + integrator.sol[1])) + else + push!(plot_vecs[j], + integrator.sol.prob.f(Val{:analytic}, integrator.t, + integrator.sol[1])[x[j]]) + end + end + end + add_labels!(labels, x, dims, integrator.sol, strs) + end + end + + xflip --> integrator.tdir < 0 + + if denseplot + seriestype --> :path + else + seriestype --> :scatter + end + + # Special case labels when idxs = (:x,:y,:z) or (:x) or [:x,:y] .. + if idxs isa Tuple && (typeof(idxs[1]) == Symbol && typeof(idxs[2]) == Symbol) + xlabel --> idxs[1] + ylabel --> idxs[2] + if length(idxs) > 2 + zlabel --> idxs[3] + end + end + if getindex.(int_vars, 1) == zeros(length(int_vars)) || + getindex.(int_vars, 2) == zeros(length(int_vars)) + xlabel --> "t" + end + + linewidth --> 3 + #xtickfont --> font(11) + #ytickfont --> font(11) + #legendfont --> font(11) + #guidefont --> font(11) + label --> reshape(labels, 1, length(labels)) + (plot_vecs...,) +end + +end diff --git a/ext/SciMLBaseRuntimeGeneratedFunctionsExt.jl b/ext/SciMLBaseRuntimeGeneratedFunctionsExt.jl new file mode 100644 index 000000000..f44d2ec7c --- /dev/null +++ b/ext/SciMLBaseRuntimeGeneratedFunctionsExt.jl @@ -0,0 +1,20 @@ +module SciMLBaseRuntimeGeneratedFunctionsExt + +using SciMLBase +using RuntimeGeneratedFunctions + +function SciMLBase.numargs(f::RuntimeGeneratedFunctions.RuntimeGeneratedFunction{ + T, + V, + W, + I +}) where { + T, + V, + W, + I +} + (length(T),) +end + +end diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index 49d9e246d..d16148f3f 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -4,7 +4,7 @@ if isdefined(Base, :Experimental) && @eval Base.Experimental.@max_methods 1 end using ConstructionBase -using RecipesBase, RecursiveArrayTools +using RecursiveArrayTools using SciMLStructures using SymbolicIndexingInterface using DocStringExtensions @@ -19,12 +19,10 @@ import Logging, ArrayInterface import IteratorInterfaceExtensions import CommonSolve: solve, init, step!, solve! import FunctionWrappersWrappers -import RuntimeGeneratedFunctions import EnumX import ADTypes: ADTypes, AbstractADType import Accessors: @set, @reset, @delete, @insert -using Moshi.Data: @data -using Moshi.Match: @match +# Moshi moved to extension for load time optimization import StaticArraysCore import Adapt: adapt_structure, adapt diff --git a/src/clock.jl b/src/clock.jl index a417ebea4..325d520b9 100644 --- a/src/clock.jl +++ b/src/clock.jl @@ -1,94 +1,7 @@ -@data Clocks begin - ContinuousClock - struct PeriodicClock - dt::Union{Nothing, Float64, Rational{Int}} - phase::Float64 = 0.0 - end - SolverStepClock -end +# Clock system implementation +# +# When Moshi is available as an extension, it provides advanced @data/@match functionality. +# When Moshi is not loaded, we fall back to simple struct-based implementations. -# for backwards compatibility -const TimeDomain = Clocks.Type -using .Clocks: ContinuousClock, PeriodicClock, SolverStepClock -const Continuous = ContinuousClock() -(clock::TimeDomain)() = clock - -Base.Broadcast.broadcastable(d::TimeDomain) = Ref(d) - -""" - Clock(dt) - Clock() - -The default periodic clock with tick interval `dt`. If `dt` is left unspecified, it will -be inferred (if possible). -""" -Clock(dt::Union{<:Rational, Float64}; phase = 0.0) = PeriodicClock(dt, phase) -Clock(dt; phase = 0.0) = PeriodicClock(convert(Float64, dt), phase) -Clock(; phase = 0.0) = PeriodicClock(nothing, phase) - -@doc """ - SolverStepClock - -A clock that ticks at each solver step (sometimes referred to as "continuous sample time"). -This clock **does generally not have equidistant tick intervals**, instead, the tick -interval depends on the adaptive step-size selection of the continuous solver, as well as -any continuous event handling. If adaptivity of the solver is turned off and there are no -continuous events, the tick interval will be given by the fixed solver time step `dt`. - -Due to possibly non-equidistant tick intervals, this clock should typically not be used with -discrete-time systems that assume a fixed sample time, such as PID controllers and digital -filters. -""" SolverStepClock - -isclock(c::TimeDomain) = @match c begin - PeriodicClock() => true - _ => false -end - -issolverstepclock(c::TimeDomain) = @match c begin - SolverStepClock() => true - _ => false -end - -iscontinuous(c::TimeDomain) = @match c begin - ContinuousClock() => true - _ => false -end - -is_discrete_time_domain(c::TimeDomain) = !iscontinuous(c) - -# workaround for https://github.com/Roger-luo/Moshi.jl/issues/43 -isclock(::Any) = false -issolverstepclock(::Any) = false -iscontinuous(::Any) = false -is_discrete_time_domain(::Any) = false - -function first_clock_tick_time(c, t0) - @match c begin - PeriodicClock(dt) => ceil(t0 / dt) * dt - SolverStepClock() => t0 - ContinuousClock() => error("ContinuousClock() is not a discrete clock") - end -end - -struct IndexedClock{I} - clock::TimeDomain - idx::I -end - -Base.getindex(c::TimeDomain, idx) = IndexedClock(c, idx) - -function canonicalize_indexed_clock(ic::IndexedClock, sol::AbstractTimeseriesSolution) - c = ic.clock - - return @match c begin - PeriodicClock(dt) => ceil(sol.prob.tspan[1] / dt) * dt .+ (ic.idx .- 1) .* dt - SolverStepClock() => begin - ssc_idx = findfirst(eachindex(sol.discretes)) do i - !isa(sol.discretes[i].t, AbstractRange) - end - sol.discretes[ssc_idx].t[ic.idx] - end - ContinuousClock() => sol.t[ic.idx] - end -end +# Include fallback implementations (always available) +include("clock_fallback.jl") diff --git a/src/clock_fallback.jl b/src/clock_fallback.jl new file mode 100644 index 000000000..6a59ae07c --- /dev/null +++ b/src/clock_fallback.jl @@ -0,0 +1,111 @@ +# Fallback clock implementations when Moshi is not available +# These provide basic clock functionality without pattern matching + +# Simple struct-based clock definitions (fallback when Moshi not loaded) +abstract type TimeDomain end + +struct ContinuousClock <: TimeDomain end + +struct PeriodicClock <: TimeDomain + dt::Union{Nothing, Float64, Rational{Int}} + phase::Float64 + function PeriodicClock(dt::Union{Nothing, Float64, Rational{Int}}, phase::Float64 = 0.0) + new(dt, phase) + end +end + +struct SolverStepClock <: TimeDomain end + +# for backwards compatibility +const Continuous = ContinuousClock() +(clock::TimeDomain)() = clock + +Base.Broadcast.broadcastable(d::TimeDomain) = Ref(d) + +""" + Clock(dt) + Clock() + +The default periodic clock with tick interval `dt`. If `dt` is left unspecified, it will +be inferred (if possible). +""" +Clock(dt::Union{<:Rational, Float64}; phase = 0.0) = PeriodicClock(dt, phase) +Clock(dt; phase = 0.0) = PeriodicClock(convert(Float64, dt), phase) +Clock(; phase = 0.0) = PeriodicClock(nothing, phase) + +@doc """ + SolverStepClock + +A clock that ticks at each solver step (sometimes referred to as "continuous sample time"). +This clock **does generally not have equidistant tick intervals**, instead, the tick +interval depends on the adaptive step-size selection of the continuous solver, as well as +any continuous event handling. If adaptivity of the solver is turned off and there are no +continuous events, the tick interval will be given by the fixed solver time step `dt`. + +Due to possibly non-equidistant tick intervals, this clock should typically not be used with +discrete-time systems that assume a fixed sample time, such as PID controllers and digital +filters. +""" SolverStepClock + +# Fallback implementations without pattern matching +isclock(c::PeriodicClock) = true +isclock(c::TimeDomain) = false + +issolverstepclock(c::SolverStepClock) = true +issolverstepclock(c::TimeDomain) = false + +iscontinuous(c::ContinuousClock) = true +iscontinuous(c::TimeDomain) = false + +is_discrete_time_domain(c::TimeDomain) = !iscontinuous(c) + +# workaround for fallback when argument is not a TimeDomain +isclock(::Any) = false +issolverstepclock(::Any) = false +iscontinuous(::Any) = false +is_discrete_time_domain(::Any) = false + +function first_clock_tick_time(c::PeriodicClock, t0) + dt = c.dt + ceil(t0 / dt) * dt +end +function first_clock_tick_time(c::SolverStepClock, t0) + t0 +end +function first_clock_tick_time(c::ContinuousClock, t0) + error("ContinuousClock() is not a discrete clock") +end + +struct IndexedClock{I} + clock::TimeDomain + idx::I +end + +Base.getindex(c::TimeDomain, idx) = IndexedClock(c, idx) + +function canonicalize_indexed_clock(ic::IndexedClock, sol::AbstractTimeseriesSolution) + c = ic.clock + + if c isa PeriodicClock + dt = c.dt + return ceil(sol.prob.tspan[1] / dt) * dt .+ (ic.idx .- 1) .* dt + elseif c isa SolverStepClock + ssc_idx = findfirst(eachindex(sol.discretes)) do i + !isa(sol.discretes[i].t, AbstractRange) + end + return sol.discretes[ssc_idx].t[ic.idx] + elseif c isa ContinuousClock + return sol.t[ic.idx] + else + error("Unknown clock type: $(typeof(c))") + end +end + +# Define stub functions that can be overridden by extensions +isclock_moshi(c::TimeDomain) = isclock(c) +issolverstepclock_moshi(c::TimeDomain) = issolverstepclock(c) +iscontinuous_moshi(c::TimeDomain) = iscontinuous(c) +first_clock_tick_time_moshi(c, t0) = first_clock_tick_time(c, t0) +function canonicalize_indexed_clock_moshi(ic::IndexedClock, sol::AbstractTimeseriesSolution) + canonicalize_indexed_clock(ic, sol) +end diff --git a/src/debug.jl b/src/debug.jl index d8c1ae21d..594db5778 100644 --- a/src/debug.jl +++ b/src/debug.jl @@ -58,7 +58,8 @@ expression. Two common reasons for this issue are: function __init__() Base.Experimental.register_error_hint(DomainError) do io, e - if e isa DomainError && occursin("will only return a complex result if called with a complex argument. Try ", e.msg) + if e isa DomainError && + occursin("will only return a complex result if called with a complex argument. Try ", e.msg) println(io, DOMAINERROR_COMPLEX_MSG) end end diff --git a/src/ensemble/ensemble_solutions.jl b/src/ensemble/ensemble_solutions.jl index ab25d1181..2968107a5 100644 --- a/src/ensemble/ensemble_solutions.jl +++ b/src/ensemble/ensemble_solutions.jl @@ -184,77 +184,6 @@ end ### Plot Recipes -@recipe function f(sim::AbstractEnsembleSolution; - zcolors = sim.u isa AbstractArray ? fill(nothing, length(sim.u)) : - nothing, - trajectories = eachindex(sim)) - for i in trajectories - size(sim.u[i].u, 1) == 0 && continue - @series begin - legend := false - xlims --> (-Inf, Inf) - ylims --> (-Inf, Inf) - zlims --> (-Inf, Inf) - marker_z --> zcolors[i] - sim.u[i] - end - end -end - -@recipe function f(sim::EnsembleSummary; - idxs = sim.u.u[1] isa AbstractArray ? eachindex(sim.u.u[1]) : - 1, - error_style = :ribbon, ci_type = :quantile) - if ci_type == :SEM - if sim.u.u[1] isa AbstractArray - u = vecarr_to_vectors(sim.u) - else - u = [sim.u.u] - end - if sim.u.u[1] isa AbstractArray - ci_low = vecarr_to_vectors(VectorOfArray([sqrt.(sim.v.u[i] / sim.num_monte) .* - 1.96 for i in 1:length(sim.v)])) - ci_high = ci_low - else - ci_low = [[sqrt(sim.v.u[i] / length(sim.num_monte)) .* 1.96 - for i in 1:length(sim.v)]] - ci_high = ci_low - end - elseif ci_type == :quantile - if sim.med.u[1] isa AbstractArray - u = vecarr_to_vectors(sim.med) - else - u = [sim.med.u] - end - if sim.u.u[1] isa AbstractArray - ci_low = u - vecarr_to_vectors(sim.qlow) - ci_high = vecarr_to_vectors(sim.qhigh) - u - else - ci_low = [u[1] - sim.qlow.u] - ci_high = [sim.qhigh.u - u[1]] - end - else - error("ci_type choice not valid. Must be `:SEM` or `:quantile`") - end - for i in idxs - @series begin - legend --> false - linewidth --> 3 - fillalpha --> 0.2 - if error_style == :ribbon - ribbon --> (ci_low[i], ci_high[i]) - elseif error_style == :bars - yerror --> (ci_low[i], ci_high[i]) - elseif error_style == :none - nothing - else - error("error_style not recognized") - end - sim.t, u[i] - end - end -end - function (sol::AbstractEnsembleSolution)(args...; kwargs...) [s(args...; kwargs...) for s in sol] end diff --git a/src/integrator_interface.jl b/src/integrator_interface.jl index 3b1376285..14dbf24e8 100644 --- a/src/integrator_interface.jl +++ b/src/integrator_interface.jl @@ -773,136 +773,6 @@ end Base.length(iter::TimeChoiceIterator) = length(iter.ts) -@recipe function f(integrator::DEIntegrator; - denseplot = (integrator.opts.calck || - integrator isa AbstractSDEIntegrator) && - integrator.iter > 0, - plotdensity = 10, - plot_analytic = false, vars = nothing, idxs = nothing) - if vars !== nothing - Base.depwarn( - "To maintain consistency with solution indexing, keyword argument vars will be removed in a future version. Please use keyword argument idxs instead.", - :f; force = true) - (idxs !== nothing) && - error("Simultaneously using keywords vars and idxs is not supported. Please only use idxs.") - idxs = vars - end - - int_vars = interpret_vars(idxs, integrator.sol) - - if denseplot - # Generate the points from the plot from dense function - plott = collect(range(integrator.tprev, integrator.t; length = plotdensity)) - if plot_analytic - plot_analytic_timeseries = [integrator.sol.prob.f.analytic( - integrator.sol.prob.u0, - integrator.sol.prob.p, - t) for t in plott] - end - else - plott = nothing - end - - dims = length(int_vars[1]) - for var in int_vars - @assert length(var) == dims - end - # Should check that all have the same dims! - - plot_vecs = [] - for i in 2:dims - push!(plot_vecs, []) - end - - labels = String[]# Array{String, 2}(1, length(int_vars)*(1+plot_analytic)) - strs = String[] - varsyms = variable_symbols(integrator) - @show plott - - for x in int_vars - for j in 2:dims - if denseplot - if (x[j] isa Integer && x[j] == 0) || - isequal(x[j], getindepsym_defaultt(integrator)) - push!(plot_vecs[j - 1], plott) - else - push!(plot_vecs[j - 1], Vector(integrator(plott; idxs = x[j]))) - end - else # just get values - if x[j] == 0 - push!(plot_vecs[j - 1], integrator.t) - elseif x[j] == 1 && !(integrator.u isa AbstractArray) - push!(plot_vecs[j - 1], integrator.u) - else - push!(plot_vecs[j - 1], integrator.u[x[j]]) - end - end - - if !isempty(varsyms) && x[j] isa Integer - push!(strs, String(getname(varsyms[x[j]]))) - elseif hasname(x[j]) - push!(strs, String(getname(x[j]))) - else - push!(strs, "u[$(x[j])]") - end - end - add_labels!(labels, x, dims, integrator.sol, strs) - end - - if plot_analytic - for x in int_vars - for j in 1:dims - if denseplot - push!(plot_vecs[j], - u_n(plot_timeseries, x[j], sol, plott, plot_timeseries)) - else # Just get values - if x[j] == 0 - push!(plot_vecs[j], integrator.t) - elseif x[j] == 1 && !(integrator.u isa AbstractArray) - push!(plot_vecs[j], - integrator.sol.prob.f(Val{:analytic}, integrator.t, - integrator.sol[1])) - else - push!(plot_vecs[j], - integrator.sol.prob.f(Val{:analytic}, integrator.t, - integrator.sol[1])[x[j]]) - end - end - end - add_labels!(labels, x, dims, integrator.sol, strs) - end - end - - xflip --> integrator.tdir < 0 - - if denseplot - seriestype --> :path - else - seriestype --> :scatter - end - - # Special case labels when idxs = (:x,:y,:z) or (:x) or [:x,:y] ... - if idxs isa Tuple && (typeof(idxs[1]) == Symbol && typeof(idxs[2]) == Symbol) - xlabel --> idxs[1] - ylabel --> idxs[2] - if length(idxs) > 2 - zlabel --> idxs[3] - end - end - if getindex.(int_vars, 1) == zeros(length(int_vars)) || - getindex.(int_vars, 2) == zeros(length(int_vars)) - xlabel --> "t" - end - - linewidth --> 3 - #xtickfont --> font(11) - #ytickfont --> font(11) - #legendfont --> font(11) - #guidefont --> font(11) - label --> reshape(labels, 1, length(labels)) - (plot_vecs...,) -end - function step!(integ::DEIntegrator, dt, stop_at_tdt = false) (dt * integ.tdir) < 0 * oneunit(dt) && error("Cannot step backward.") t = integ.t diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index f08a73ce9..e393acf5b 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -212,175 +212,6 @@ used for plotting. plottable_indices(x::AbstractArray) = 1:length(x) plottable_indices(x::Number) = 1 -@recipe function f(sol::AbstractTimeseriesSolution; - plot_analytic = false, - denseplot = isdenseplot(sol), - plotdensity = min(Int(1e5), - sol.tslocation == 0 ? - (sol.prob isa AbstractDiscreteProblem ? - max(1000, 100 * length(sol)) : - max(1000, 10 * length(sol))) : - 1000 * sol.tslocation), plotat = nothing, - tspan = nothing, - vars = nothing, idxs = nothing) - if vars !== nothing - Base.depwarn( - "To maintain consistency with solution indexing, keyword argument vars will be removed in a future version. Please use keyword argument idxs instead.", - :f; force = true) - (idxs !== nothing) && - error("Simultaneously using keywords vars and idxs is not supported. Please only use idxs.") - idxs = vars - end - - if plot_analytic && (sol.u_analytic === nothing) - throw(ArgumentError("No analytic solution was found but `plot_analytic` was set to `true`.")) - end - - idxs = idxs === nothing ? plottable_indices(sol.u[1]) : idxs - if !(idxs isa Union{Tuple, AbstractArray}) - vars = interpret_vars([idxs], sol) - else - vars = interpret_vars(idxs, sol) - end - disc_vars = Tuple[] - cont_vars = Tuple[] - for var in vars - tsidxs = union(get_all_timeseries_indexes(sol, var[2]), - get_all_timeseries_indexes(sol, var[3])) - if ContinuousTimeseries() in tsidxs || isempty(tsidxs) - push!(cont_vars, var) - else - push!(disc_vars, (var..., only(tsidxs))) - end - end - idxs = identity.(cont_vars) - vars = identity.(cont_vars) - tdir = sign(sol.t[end] - sol.t[1]) - xflip --> tdir < 0 - seriestype --> :path - - @series begin - if idxs isa Union{AbstractArray, Tuple} && isempty(idxs) - label --> nothing - ([], []) - else - tscale = get(plotattributes, :xscale, :identity) - plot_vecs, - labels = diffeq_to_arrays(sol, plot_analytic, denseplot, - plotdensity, tspan, vars, tscale, plotat) - - # Special case labels when idxs = (:x,:y,:z) or (:x) or [:x,:y] ... - if idxs isa Tuple && vars[1][1] === DEFAULT_PLOT_FUNC - val = hasname(vars[1][2]) ? String(getname(vars[1][2])) : vars[1][2] - if val isa Integer - if val == 0 - val = "t" - else - val = "u[$val]" - end - end - xguide --> val - val = hasname(vars[1][3]) ? String(getname(vars[1][3])) : vars[1][3] - if val isa Integer - if val == 0 - val = "t" - else - val = "u[$val]" - end - end - yguide --> val - if length(idxs) > 2 - val = hasname(vars[1][4]) ? String(getname(vars[1][4])) : vars[1][4] - if val isa Integer - if val == 0 - val = "t" - else - val = "u[$val]" - end - end - zguide --> val - end - end - - if (!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 1))) && - getindex.(vars, 1) == zeros(length(vars))) || - (!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 2))) && - getindex.(vars, 2) == zeros(length(vars))) || - all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 1)) || - all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 2)) - xguide --> "$(getindepsym_defaultt(sol))" - end - if length(vars[1]) >= 3 && - ((!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 3))) && - getindex.(vars, 3) == zeros(length(vars))) || - all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 3))) - yguide --> "$(getindepsym_defaultt(sol))" - end - if length(vars[1]) >= 4 && - ((!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 4))) && - getindex.(vars, 4) == zeros(length(vars))) || - all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 4))) - zguide --> "$(getindepsym_defaultt(sol))" - end - - if (!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 2))) && - getindex.(vars, 2) == zeros(length(vars))) || - all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 2)) - if tspan === nothing - if tdir > 0 - xlims --> (sol.t[1], sol.t[end]) - else - xlims --> (sol.t[end], sol.t[1]) - end - else - xlims --> (tspan[1], tspan[end]) - end - end - - label --> reshape(labels, 1, length(labels)) - (plot_vecs...,) - end - end - for (func, xvar, yvar, tsidx) in disc_vars - partition = sol.discretes[tsidx] - ts = current_time(partition) - if tspan !== nothing - tstart = searchsortedfirst(ts, tspan[1]) - tend = searchsortedlast(ts, tspan[2]) - if tstart == lastindex(ts) + 1 || tend == firstindex(ts) - 1 - continue - end - else - tstart = firstindex(ts) - tend = lastindex(ts) - end - ts = ts[tstart:tend] - - if symbolic_type(xvar) == NotSymbolic() && xvar == 0 - xvar = only(independent_variable_symbols(sol)) - end - xvals = sol(ts; idxs = xvar).u - # xvals = getsym(sol, xvar)(sol, tstart:tend) - yvals = getp(sol, yvar)(sol, tstart:tend) - tmpvals = map(func, xvals, yvals) - xvals = getindex.(tmpvals, 1) - yvals = getindex.(tmpvals, 2) - # Scatterplot of points - @series begin - seriestype := :line - linestyle --> :dash - markershape --> :o - markersize --> repeat([2, 0], length(ts) - 1) - markeralpha --> repeat([1, 0], length(ts) - 1) - label --> string(hasname(yvar) ? getname(yvar) : yvar) - - x = vec([xvals[1:(end - 1)]'; xvals[2:end]']) - y = repeat(yvals, inner = 2)[1:(end - 1)] - x, y - end - end -end - function diffeq_to_arrays(sol, plot_analytic, denseplot, plotdensity, tspan, vars, tscale, plotat) if tspan === nothing diff --git a/src/utils.jl b/src/utils.jl index ecded5af1..e8c47df25 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -13,20 +13,6 @@ function numargs(f) end end -function numargs(f::RuntimeGeneratedFunctions.RuntimeGeneratedFunction{ - T, - V, - W, - I -}) where { - T, - V, - W, - I -} - (length(T),) -end - numargs(f::ComposedFunction) = numargs(f.inner) """