diff --git a/Project.toml b/Project.toml index 92b029a3a..ddbef1930 100644 --- a/Project.toml +++ b/Project.toml @@ -22,10 +22,8 @@ Moshi = "2e0e35c7-a2e4-4343-998d-7ef72827ed2d" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" -RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" -RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" @@ -41,6 +39,7 @@ PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" RCall = "6f49c342-dc21-5d91-9882-a32aef131414" +RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] @@ -51,6 +50,8 @@ SciMLBasePartialFunctionsExt = "PartialFunctions" SciMLBasePyCallExt = "PyCall" SciMLBasePythonCallExt = "PythonCall" SciMLBaseRCallExt = "RCall" +SciMLBaseRecipesBaseExt = "RecipesBase" +SciMLBaseRuntimeGeneratedFunctionsExt = "RuntimeGeneratedFunctions" SciMLBaseZygoteExt = ["Zygote", "ChainRulesCore"] [compat] @@ -81,7 +82,6 @@ Printf = "1.10" PyCall = "1.96" PythonCall = "0.9.15" RCall = "0.14.0" -RecipesBase = "1.3.4" RecursiveArrayTools = "3.35" Reexport = "1" RuntimeGeneratedFunctions = "0.5.12" diff --git a/ext/SciMLBaseRecipesBaseExt.jl b/ext/SciMLBaseRecipesBaseExt.jl new file mode 100644 index 000000000..5028f9dd7 --- /dev/null +++ b/ext/SciMLBaseRecipesBaseExt.jl @@ -0,0 +1,500 @@ +module SciMLBaseRecipesBaseExt + +using SciMLBase +using RecipesBase +import RecursiveArrayTools + +# Need to import the plotting-related functions +import SciMLBase: DEFAULT_PLOT_FUNC, isdenseplot, plottable_indices, interpret_vars, + get_all_timeseries_indexes, ContinuousTimeseries, DiscreteTimeseries, + solution_slice, add_labels!, AbstractTimeseriesSolution, AbstractEnsembleSolution, + AbstractNoTimeSolution, EnsembleSummary, DEIntegrator, AbstractSDEIntegrator, + getindepsym_defaultt, getname, hasname, u_n, AbstractDEAlgorithm + +# Recipe for AbstractTimeseriesSolution +@recipe function f(sol::AbstractTimeseriesSolution; + plot_analytic = false, + denseplot = isdenseplot(sol), + plotdensity = min(Int(1e5), + sol.tslocation == 0 ? + (sol.prob isa SciMLBase.AbstractDiscreteProblem ? + max(1000, 100 * length(sol)) : + max(1000, 10 * length(sol))) : + 1000 * sol.tslocation), plotat = nothing, + tspan = nothing, + vars = nothing, idxs = nothing) + if vars !== nothing + Base.depwarn( + "To maintain consistency with solution indexing, keyword argument vars will be removed in a future version. Please use keyword argument idxs instead.", + :f; force = true) + (idxs !== nothing) && + error("Simultaneously using keywords vars and idxs is not supported. Please only use idxs.") + idxs = vars + end + + if plot_analytic && (sol.u_analytic === nothing) + throw(ArgumentError("No analytic solution was found but `plot_analytic` was set to `true`.")) + end + + idxs = idxs === nothing ? plottable_indices(sol.u[1]) : idxs + if !(idxs isa Union{Tuple, AbstractArray}) + vars = interpret_vars([idxs], sol) + else + vars = interpret_vars(idxs, sol) + end + disc_vars = Tuple[] + cont_vars = Tuple[] + for var in vars + tsidxs = union(get_all_timeseries_indexes(sol, var[2]), + get_all_timeseries_indexes(sol, var[3])) + if ContinuousTimeseries() in tsidxs || isempty(tsidxs) + push!(cont_vars, var) + else + push!(disc_vars, var) + end + end + + plot_vecs = [] + labels = [] + + # Handle continuous variables + if !isempty(cont_vars) + int_vars = cont_vars + + if tspan === nothing + if plotat === nothing + if denseplot + # Generate the points from the plot from dense function + start_idx = sol.tslocation == 0 ? 1 : sol.tslocation + end_idx = length(sol.t) + plott = collect(range(sol.t[start_idx], sol.t[end_idx]; length = plotdensity)) + plot_timeseries = sol(plott) + if plot_analytic + plot_analytic_timeseries = [sol.prob.f.analytic(sol.prob.u0, + sol.prob.p, t) + for t in plott] + end + else + plot_timeseries = sol.u + plott = sol.t + if plot_analytic + plot_analytic_timeseries = sol.u_analytic + end + end + else + plot_timeseries = sol(plotat) + plott = plotat + if plot_analytic + plot_analytic_timeseries = [sol.prob.f.analytic(sol.prob.u0, + sol.prob.p, t) for t in plott] + end + end + else + _tspan = tspan isa Number ? (sol.t[1], tspan) : tspan + start_idx = findfirst(x -> x >= _tspan[1], sol.t) + end_idx = findlast(x -> x <= _tspan[2], sol.t) + if denseplot + plott = collect(range(_tspan...; length = plotdensity)) + plot_timeseries = sol(plott) + if plot_analytic + plot_analytic_timeseries = [sol.prob.f.analytic(sol.prob.u0, + sol.prob.p, t) for t in plott] + end + else + if start_idx === nothing + start_idx = 1 + end + if end_idx === nothing + end_idx = length(sol.t) + end + plott = @view sol.t[start_idx:end_idx] + plot_timeseries = @view sol.u[start_idx:end_idx] + if plot_analytic + plot_analytic_timeseries = @view sol.u_analytic[start_idx:end_idx] + end + end + end + + dims = length(int_vars[1]) + for var in int_vars + @assert length(var) == dims + end + # Should check that all have the same dims! + + for i in 2:dims + push!(plot_vecs, []) + end + + labels = String[]# Array{String, 2}(1, length(int_vars)*(1+plot_analytic)) + strs = String[] + varsyms = SciMLBase.variable_symbols(sol) + + for x in int_vars + for j in 2:dims + if denseplot + if (x[j] isa Integer && x[j] == 0) || + isequal(x[j], SciMLBase.getindepsym_defaultt(sol)) + push!(plot_vecs[j - 1], plott) + else + # For the dense plotting case, we use getindex on the timeseries + if plot_timeseries isa AbstractArray + if x[j] isa Integer + # Simple integer indexing + push!(plot_vecs[j - 1], [u[x[j]] for u in plot_timeseries]) + else + # Symbol indexing + push!(plot_vecs[j - 1], Vector(sol(plott; idxs = x[j]))) + end + else + # Single value case + push!(plot_vecs[j - 1], Vector(sol(plott; idxs = x[j]))) + end + end + else # just get values + if x[j] == 0 + push!(plot_vecs[j - 1], plott) + elseif x[j] == 1 && !(eltype(plot_timeseries) <: AbstractArray) + push!(plot_vecs[j - 1], plot_timeseries) + else + if x[j] isa Integer + push!(plot_vecs[j - 1], [u[x[j]] for u in plot_timeseries]) + else + # Symbol indexing + push!(plot_vecs[j - 1], [sol(t, idxs = x[j]) for t in plott]) + end + end + end + + if !isempty(varsyms) && x[j] isa Integer + push!(strs, String(getname(varsyms[x[j]]))) + elseif hasname(x[j]) + push!(strs, String(getname(x[j]))) + else + push!(strs, "u[$(x[j])]") + end + end + add_labels!(labels, x, dims, sol, strs) + end + + if plot_analytic + for x in int_vars + for j in 2:dims + if denseplot + if (x[j] isa Integer && x[j] == 0) || + isequal(x[j], SciMLBase.getindepsym_defaultt(sol)) + push!(plot_vecs[j - 1], plott) + else + push!(plot_vecs[j - 1], + u_n(plot_analytic_timeseries, x[j], sol, plott, + plot_analytic_timeseries)) + end + else # Just get values + if x[j] == 0 + push!(plot_vecs[j - 1], plott) + elseif x[j] == 1 && + !(eltype(plot_analytic_timeseries) <: AbstractArray) + push!(plot_vecs[j - 1], plot_analytic_timeseries) + else + push!(plot_vecs[j - 1], + u_n(plot_analytic_timeseries, x[j], sol, plott, + plot_analytic_timeseries)) + end + end + end + add_labels!(labels, x, dims, sol, strs) + end + end + + xflip --> sol.tdir < 0 + + if denseplot + seriestype --> :path + else + seriestype --> :scatter + end + + # Special case labels when idxs = (:x,:y,:z) or (:x) or [:x,:y] .. + if idxs isa Tuple && (typeof(idxs[1]) == Symbol && typeof(idxs[2]) == Symbol) + xlabel --> idxs[1] + ylabel --> idxs[2] + if length(idxs) > 2 + zlabel --> idxs[3] + end + end + if getindex.(int_vars, 1) == zeros(length(int_vars)) || + getindex.(int_vars, 2) == zeros(length(int_vars)) + xlabel --> "t" + end + + linewidth --> 3 + #xtickfont --> font(11) + #ytickfont --> font(11) + #legendfont --> font(11) + #guidefont --> font(11) + label --> reshape(labels, 1, length(labels)) + (plot_vecs...,) + + # Handle discrete variables + elseif !isempty(disc_vars) + int_vars = disc_vars + + if sol.tslocation != 0 + start_idx = sol.tslocation + else + start_idx = 1 + end + + if tspan === nothing + end_idx = length(sol.t) + else + _tspan = tspan isa Number ? (sol.t[1], tspan) : tspan + end_idx = findlast(x -> x <= _tspan[2], sol.t) + if end_idx === nothing + end_idx = length(sol.t) + end + end + + plott = sol.t[start_idx:end_idx] + plot_timeseries = sol.u[start_idx:end_idx] + + dims = length(int_vars[1]) + for var in int_vars + @assert length(var) == dims + end + + for i in 2:dims + push!(plot_vecs, []) + end + + labels = String[] + strs = String[] + varsyms = SciMLBase.variable_symbols(sol) + + for x in int_vars + for j in 2:dims + if x[j] == 0 + push!(plot_vecs[j - 1], plott) + elseif x[j] == 1 && !(eltype(plot_timeseries) <: AbstractArray) + push!(plot_vecs[j - 1], plot_timeseries) + else + if x[j] isa Integer + push!(plot_vecs[j - 1], [u[x[j]] for u in plot_timeseries]) + else + # Symbol indexing for discrete case + push!(plot_vecs[j - 1], [sol[ti, x[j]] for ti in 1:length(plott)]) + end + end + + if !isempty(varsyms) && x[j] isa Integer + push!(strs, String(getname(varsyms[x[j]]))) + elseif hasname(x[j]) + push!(strs, String(getname(x[j]))) + else + push!(strs, "u[$(x[j])]") + end + end + add_labels!(labels, x, dims, sol, strs) + end + + seriestype --> :steppost + if getindex.(int_vars, 1) == zeros(length(int_vars)) || + getindex.(int_vars, 2) == zeros(length(int_vars)) + xlabel --> "t" + end + + linewidth --> 3 + label --> reshape(labels, 1, length(labels)) + (plot_vecs...,) + end +end + +# Recipe for AbstractEnsembleSolution +@recipe function f(sim::AbstractEnsembleSolution; idxs = nothing, + summarize = true, error_style = :ribbon, ci_type = :quantile, linealpha = 0.4, zorder = 1) + + if idxs === nothing + if sim.u[1] isa SciMLBase.AbstractTimeseriesSolution + idxs = 1:length(sim.u[1].u[1]) + else + idxs = 1:length(sim.u[1]) + end + end + + if !(idxs isa Union{Tuple, AbstractArray}) + idxs = [idxs] + end + + if summarize + summary = EnsembleSummary(sim; quantiles = [0.05, 0.95]) + if error_style == :ribbon + ribbon --> (summary.qlow[:, idxs], summary.qhigh[:, idxs]) + elseif error_style == :bars + yerror --> (summary.qlow[:, idxs], summary.qhigh[:, idxs]) + end + summary.t, summary.med[:, idxs] + else + alpha --> linealpha + # Plot all trajectories + for i in eachindex(sim.u) + @series begin + if sim.u[i] isa SciMLBase.AbstractTimeseriesSolution + idxs --> idxs + sim.u[i] + else + # For non-timeseries solutions + sim.u[i][idxs] + end + end + end + end +end + +# Recipe for EnsembleSummary +@recipe function f(sim::EnsembleSummary; idxs = nothing, error_style = :ribbon) + if idxs === nothing + idxs = 1:size(sim.med, 2) + end + + if !(idxs isa Union{Tuple, AbstractArray}) + idxs = [idxs] + end + + if error_style == :ribbon + ribbon --> (sim.qlow[:, idxs], sim.qhigh[:, idxs]) + elseif error_style == :bars + yerror --> (sim.qlow[:, idxs], sim.qhigh[:, idxs]) + end + sim.t, sim.med[:, idxs] +end + +# Recipe for DEIntegrator +@recipe function f(integrator::DEIntegrator; + denseplot = (integrator.opts.calck || + integrator isa AbstractSDEIntegrator) && + integrator.iter > 0, + plotdensity = 10, + plot_analytic = false, vars = nothing, idxs = nothing) + if vars !== nothing + Base.depwarn( + "To maintain consistency with solution indexing, keyword argument vars will be removed in a future version. Please use keyword argument idxs instead.", + :f; force = true) + (idxs !== nothing) && + error("Simultaneously using keywords vars and idxs is not supported. Please only use idxs.") + idxs = vars + end + + int_vars = interpret_vars(idxs, integrator.sol) + + if denseplot + # Generate the points from the plot from dense function + plott = collect(range(integrator.tprev, integrator.t; length = plotdensity)) + if plot_analytic + plot_analytic_timeseries = [integrator.sol.prob.f.analytic( + integrator.sol.prob.u0, + integrator.sol.prob.p, + t) for t in plott] + end + else + plott = nothing + end + + dims = length(int_vars[1]) + for var in int_vars + @assert length(var) == dims + end + # Should check that all have the same dims! + + plot_vecs = [] + for i in 2:dims + push!(plot_vecs, []) + end + + labels = String[]# Array{String, 2}(1, length(int_vars)*(1+plot_analytic)) + strs = String[] + varsyms = SciMLBase.variable_symbols(integrator) + + for x in int_vars + for j in 2:dims + if denseplot + if (x[j] isa Integer && x[j] == 0) || + isequal(x[j], getindepsym_defaultt(integrator)) + push!(plot_vecs[j - 1], plott) + else + push!(plot_vecs[j - 1], Vector(integrator(plott; idxs = x[j]))) + end + else # just get values + if x[j] == 0 + push!(plot_vecs[j - 1], integrator.t) + elseif x[j] == 1 && !(integrator.u isa AbstractArray) + push!(plot_vecs[j - 1], integrator.u) + else + push!(plot_vecs[j - 1], integrator.u[x[j]]) + end + end + + if !isempty(varsyms) && x[j] isa Integer + push!(strs, String(getname(varsyms[x[j]]))) + elseif hasname(x[j]) + push!(strs, String(getname(x[j]))) + else + push!(strs, "u[$(x[j])]") + end + end + add_labels!(labels, x, dims, integrator.sol, strs) + end + + if plot_analytic + for x in int_vars + for j in 1:dims + if denseplot + push!(plot_vecs[j], + u_n(plot_timeseries, x[j], sol, plott, plot_timeseries)) + else # Just get values + if x[j] == 0 + push!(plot_vecs[j], integrator.t) + elseif x[j] == 1 && !(integrator.u isa AbstractArray) + push!(plot_vecs[j], + integrator.sol.prob.f(Val{:analytic}, integrator.t, + integrator.sol[1])) + else + push!(plot_vecs[j], + integrator.sol.prob.f(Val{:analytic}, integrator.t, + integrator.sol[1])[x[j]]) + end + end + end + add_labels!(labels, x, dims, integrator.sol, strs) + end + end + + xflip --> integrator.tdir < 0 + + if denseplot + seriestype --> :path + else + seriestype --> :scatter + end + + # Special case labels when idxs = (:x,:y,:z) or (:x) or [:x,:y] .. + if idxs isa Tuple && (typeof(idxs[1]) == Symbol && typeof(idxs[2]) == Symbol) + xlabel --> idxs[1] + ylabel --> idxs[2] + if length(idxs) > 2 + zlabel --> idxs[3] + end + end + if getindex.(int_vars, 1) == zeros(length(int_vars)) || + getindex.(int_vars, 2) == zeros(length(int_vars)) + xlabel --> "t" + end + + linewidth --> 3 + #xtickfont --> font(11) + #ytickfont --> font(11) + #legendfont --> font(11) + #guidefont --> font(11) + label --> reshape(labels, 1, length(labels)) + (plot_vecs...,) +end + +end \ No newline at end of file diff --git a/ext/SciMLBaseRuntimeGeneratedFunctionsExt.jl b/ext/SciMLBaseRuntimeGeneratedFunctionsExt.jl new file mode 100644 index 000000000..8e5eb1c56 --- /dev/null +++ b/ext/SciMLBaseRuntimeGeneratedFunctionsExt.jl @@ -0,0 +1,20 @@ +module SciMLBaseRuntimeGeneratedFunctionsExt + +using SciMLBase +using RuntimeGeneratedFunctions + +function SciMLBase.numargs(f::RuntimeGeneratedFunctions.RuntimeGeneratedFunction{ + T, + V, + W, + I +}) where { + T, + V, + W, + I +} + (length(T),) +end + +end \ No newline at end of file diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index 49d9e246d..06fb19069 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -4,7 +4,7 @@ if isdefined(Base, :Experimental) && @eval Base.Experimental.@max_methods 1 end using ConstructionBase -using RecipesBase, RecursiveArrayTools +using RecursiveArrayTools using SciMLStructures using SymbolicIndexingInterface using DocStringExtensions @@ -19,7 +19,6 @@ import Logging, ArrayInterface import IteratorInterfaceExtensions import CommonSolve: solve, init, step!, solve! import FunctionWrappersWrappers -import RuntimeGeneratedFunctions import EnumX import ADTypes: ADTypes, AbstractADType import Accessors: @set, @reset, @delete, @insert diff --git a/src/ensemble/ensemble_solutions.jl b/src/ensemble/ensemble_solutions.jl index ab25d1181..ffcb46ea7 100644 --- a/src/ensemble/ensemble_solutions.jl +++ b/src/ensemble/ensemble_solutions.jl @@ -184,76 +184,7 @@ end ### Plot Recipes -@recipe function f(sim::AbstractEnsembleSolution; - zcolors = sim.u isa AbstractArray ? fill(nothing, length(sim.u)) : - nothing, - trajectories = eachindex(sim)) - for i in trajectories - size(sim.u[i].u, 1) == 0 && continue - @series begin - legend := false - xlims --> (-Inf, Inf) - ylims --> (-Inf, Inf) - zlims --> (-Inf, Inf) - marker_z --> zcolors[i] - sim.u[i] - end - end -end -@recipe function f(sim::EnsembleSummary; - idxs = sim.u.u[1] isa AbstractArray ? eachindex(sim.u.u[1]) : - 1, - error_style = :ribbon, ci_type = :quantile) - if ci_type == :SEM - if sim.u.u[1] isa AbstractArray - u = vecarr_to_vectors(sim.u) - else - u = [sim.u.u] - end - if sim.u.u[1] isa AbstractArray - ci_low = vecarr_to_vectors(VectorOfArray([sqrt.(sim.v.u[i] / sim.num_monte) .* - 1.96 for i in 1:length(sim.v)])) - ci_high = ci_low - else - ci_low = [[sqrt(sim.v.u[i] / length(sim.num_monte)) .* 1.96 - for i in 1:length(sim.v)]] - ci_high = ci_low - end - elseif ci_type == :quantile - if sim.med.u[1] isa AbstractArray - u = vecarr_to_vectors(sim.med) - else - u = [sim.med.u] - end - if sim.u.u[1] isa AbstractArray - ci_low = u - vecarr_to_vectors(sim.qlow) - ci_high = vecarr_to_vectors(sim.qhigh) - u - else - ci_low = [u[1] - sim.qlow.u] - ci_high = [sim.qhigh.u - u[1]] - end - else - error("ci_type choice not valid. Must be `:SEM` or `:quantile`") - end - for i in idxs - @series begin - legend --> false - linewidth --> 3 - fillalpha --> 0.2 - if error_style == :ribbon - ribbon --> (ci_low[i], ci_high[i]) - elseif error_style == :bars - yerror --> (ci_low[i], ci_high[i]) - elseif error_style == :none - nothing - else - error("error_style not recognized") - end - sim.t, u[i] - end - end -end function (sol::AbstractEnsembleSolution)(args...; kwargs...) [s(args...; kwargs...) for s in sol] diff --git a/src/integrator_interface.jl b/src/integrator_interface.jl index 3b1376285..becaaf776 100644 --- a/src/integrator_interface.jl +++ b/src/integrator_interface.jl @@ -773,135 +773,6 @@ end Base.length(iter::TimeChoiceIterator) = length(iter.ts) -@recipe function f(integrator::DEIntegrator; - denseplot = (integrator.opts.calck || - integrator isa AbstractSDEIntegrator) && - integrator.iter > 0, - plotdensity = 10, - plot_analytic = false, vars = nothing, idxs = nothing) - if vars !== nothing - Base.depwarn( - "To maintain consistency with solution indexing, keyword argument vars will be removed in a future version. Please use keyword argument idxs instead.", - :f; force = true) - (idxs !== nothing) && - error("Simultaneously using keywords vars and idxs is not supported. Please only use idxs.") - idxs = vars - end - - int_vars = interpret_vars(idxs, integrator.sol) - - if denseplot - # Generate the points from the plot from dense function - plott = collect(range(integrator.tprev, integrator.t; length = plotdensity)) - if plot_analytic - plot_analytic_timeseries = [integrator.sol.prob.f.analytic( - integrator.sol.prob.u0, - integrator.sol.prob.p, - t) for t in plott] - end - else - plott = nothing - end - - dims = length(int_vars[1]) - for var in int_vars - @assert length(var) == dims - end - # Should check that all have the same dims! - - plot_vecs = [] - for i in 2:dims - push!(plot_vecs, []) - end - - labels = String[]# Array{String, 2}(1, length(int_vars)*(1+plot_analytic)) - strs = String[] - varsyms = variable_symbols(integrator) - @show plott - - for x in int_vars - for j in 2:dims - if denseplot - if (x[j] isa Integer && x[j] == 0) || - isequal(x[j], getindepsym_defaultt(integrator)) - push!(plot_vecs[j - 1], plott) - else - push!(plot_vecs[j - 1], Vector(integrator(plott; idxs = x[j]))) - end - else # just get values - if x[j] == 0 - push!(plot_vecs[j - 1], integrator.t) - elseif x[j] == 1 && !(integrator.u isa AbstractArray) - push!(plot_vecs[j - 1], integrator.u) - else - push!(plot_vecs[j - 1], integrator.u[x[j]]) - end - end - - if !isempty(varsyms) && x[j] isa Integer - push!(strs, String(getname(varsyms[x[j]]))) - elseif hasname(x[j]) - push!(strs, String(getname(x[j]))) - else - push!(strs, "u[$(x[j])]") - end - end - add_labels!(labels, x, dims, integrator.sol, strs) - end - - if plot_analytic - for x in int_vars - for j in 1:dims - if denseplot - push!(plot_vecs[j], - u_n(plot_timeseries, x[j], sol, plott, plot_timeseries)) - else # Just get values - if x[j] == 0 - push!(plot_vecs[j], integrator.t) - elseif x[j] == 1 && !(integrator.u isa AbstractArray) - push!(plot_vecs[j], - integrator.sol.prob.f(Val{:analytic}, integrator.t, - integrator.sol[1])) - else - push!(plot_vecs[j], - integrator.sol.prob.f(Val{:analytic}, integrator.t, - integrator.sol[1])[x[j]]) - end - end - end - add_labels!(labels, x, dims, integrator.sol, strs) - end - end - - xflip --> integrator.tdir < 0 - - if denseplot - seriestype --> :path - else - seriestype --> :scatter - end - - # Special case labels when idxs = (:x,:y,:z) or (:x) or [:x,:y] ... - if idxs isa Tuple && (typeof(idxs[1]) == Symbol && typeof(idxs[2]) == Symbol) - xlabel --> idxs[1] - ylabel --> idxs[2] - if length(idxs) > 2 - zlabel --> idxs[3] - end - end - if getindex.(int_vars, 1) == zeros(length(int_vars)) || - getindex.(int_vars, 2) == zeros(length(int_vars)) - xlabel --> "t" - end - - linewidth --> 3 - #xtickfont --> font(11) - #ytickfont --> font(11) - #legendfont --> font(11) - #guidefont --> font(11) - label --> reshape(labels, 1, length(labels)) - (plot_vecs...,) -end function step!(integ::DEIntegrator, dt, stop_at_tdt = false) (dt * integ.tdir) < 0 * oneunit(dt) && error("Cannot step backward.") diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index f08a73ce9..835ab7cba 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -212,174 +212,6 @@ used for plotting. plottable_indices(x::AbstractArray) = 1:length(x) plottable_indices(x::Number) = 1 -@recipe function f(sol::AbstractTimeseriesSolution; - plot_analytic = false, - denseplot = isdenseplot(sol), - plotdensity = min(Int(1e5), - sol.tslocation == 0 ? - (sol.prob isa AbstractDiscreteProblem ? - max(1000, 100 * length(sol)) : - max(1000, 10 * length(sol))) : - 1000 * sol.tslocation), plotat = nothing, - tspan = nothing, - vars = nothing, idxs = nothing) - if vars !== nothing - Base.depwarn( - "To maintain consistency with solution indexing, keyword argument vars will be removed in a future version. Please use keyword argument idxs instead.", - :f; force = true) - (idxs !== nothing) && - error("Simultaneously using keywords vars and idxs is not supported. Please only use idxs.") - idxs = vars - end - - if plot_analytic && (sol.u_analytic === nothing) - throw(ArgumentError("No analytic solution was found but `plot_analytic` was set to `true`.")) - end - - idxs = idxs === nothing ? plottable_indices(sol.u[1]) : idxs - if !(idxs isa Union{Tuple, AbstractArray}) - vars = interpret_vars([idxs], sol) - else - vars = interpret_vars(idxs, sol) - end - disc_vars = Tuple[] - cont_vars = Tuple[] - for var in vars - tsidxs = union(get_all_timeseries_indexes(sol, var[2]), - get_all_timeseries_indexes(sol, var[3])) - if ContinuousTimeseries() in tsidxs || isempty(tsidxs) - push!(cont_vars, var) - else - push!(disc_vars, (var..., only(tsidxs))) - end - end - idxs = identity.(cont_vars) - vars = identity.(cont_vars) - tdir = sign(sol.t[end] - sol.t[1]) - xflip --> tdir < 0 - seriestype --> :path - - @series begin - if idxs isa Union{AbstractArray, Tuple} && isempty(idxs) - label --> nothing - ([], []) - else - tscale = get(plotattributes, :xscale, :identity) - plot_vecs, - labels = diffeq_to_arrays(sol, plot_analytic, denseplot, - plotdensity, tspan, vars, tscale, plotat) - - # Special case labels when idxs = (:x,:y,:z) or (:x) or [:x,:y] ... - if idxs isa Tuple && vars[1][1] === DEFAULT_PLOT_FUNC - val = hasname(vars[1][2]) ? String(getname(vars[1][2])) : vars[1][2] - if val isa Integer - if val == 0 - val = "t" - else - val = "u[$val]" - end - end - xguide --> val - val = hasname(vars[1][3]) ? String(getname(vars[1][3])) : vars[1][3] - if val isa Integer - if val == 0 - val = "t" - else - val = "u[$val]" - end - end - yguide --> val - if length(idxs) > 2 - val = hasname(vars[1][4]) ? String(getname(vars[1][4])) : vars[1][4] - if val isa Integer - if val == 0 - val = "t" - else - val = "u[$val]" - end - end - zguide --> val - end - end - - if (!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 1))) && - getindex.(vars, 1) == zeros(length(vars))) || - (!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 2))) && - getindex.(vars, 2) == zeros(length(vars))) || - all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 1)) || - all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 2)) - xguide --> "$(getindepsym_defaultt(sol))" - end - if length(vars[1]) >= 3 && - ((!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 3))) && - getindex.(vars, 3) == zeros(length(vars))) || - all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 3))) - yguide --> "$(getindepsym_defaultt(sol))" - end - if length(vars[1]) >= 4 && - ((!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 4))) && - getindex.(vars, 4) == zeros(length(vars))) || - all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 4))) - zguide --> "$(getindepsym_defaultt(sol))" - end - - if (!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 2))) && - getindex.(vars, 2) == zeros(length(vars))) || - all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 2)) - if tspan === nothing - if tdir > 0 - xlims --> (sol.t[1], sol.t[end]) - else - xlims --> (sol.t[end], sol.t[1]) - end - else - xlims --> (tspan[1], tspan[end]) - end - end - - label --> reshape(labels, 1, length(labels)) - (plot_vecs...,) - end - end - for (func, xvar, yvar, tsidx) in disc_vars - partition = sol.discretes[tsidx] - ts = current_time(partition) - if tspan !== nothing - tstart = searchsortedfirst(ts, tspan[1]) - tend = searchsortedlast(ts, tspan[2]) - if tstart == lastindex(ts) + 1 || tend == firstindex(ts) - 1 - continue - end - else - tstart = firstindex(ts) - tend = lastindex(ts) - end - ts = ts[tstart:tend] - - if symbolic_type(xvar) == NotSymbolic() && xvar == 0 - xvar = only(independent_variable_symbols(sol)) - end - xvals = sol(ts; idxs = xvar).u - # xvals = getsym(sol, xvar)(sol, tstart:tend) - yvals = getp(sol, yvar)(sol, tstart:tend) - tmpvals = map(func, xvals, yvals) - xvals = getindex.(tmpvals, 1) - yvals = getindex.(tmpvals, 2) - # Scatterplot of points - @series begin - seriestype := :line - linestyle --> :dash - markershape --> :o - markersize --> repeat([2, 0], length(ts) - 1) - markeralpha --> repeat([1, 0], length(ts) - 1) - label --> string(hasname(yvar) ? getname(yvar) : yvar) - - x = vec([xvals[1:(end - 1)]'; xvals[2:end]']) - y = repeat(yvals, inner = 2)[1:(end - 1)] - x, y - end - end -end function diffeq_to_arrays(sol, plot_analytic, denseplot, plotdensity, tspan, vars, tscale, plotat) diff --git a/src/utils.jl b/src/utils.jl index ecded5af1..35b892706 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -13,19 +13,6 @@ function numargs(f) end end -function numargs(f::RuntimeGeneratedFunctions.RuntimeGeneratedFunction{ - T, - V, - W, - I -}) where { - T, - V, - W, - I -} - (length(T),) -end numargs(f::ComposedFunction) = numargs(f.inner)