Skip to content

Port Plots.jl recipe features to Makie recipe for AbstractTimeseriesSolution #1064

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 107 additions & 17 deletions ext/SciMLBaseMakieExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ function Makie.convert_arguments(PT::Type{<:Plot},
# Makie error message for convert_arguments - just at a different place.
# TODO: this is a bit of a hack, but of course one can define specific dispatches elsewhere...
ensure_plottrait(PT, sol, Makie.PointBased)

# Helper function to determine plottable indices
plottable_indices(x::AbstractArray) = 1:length(x)
plottable_indices(x::Number) = 1

if vars !== nothing
Base.depwarn(
Expand All @@ -69,22 +73,36 @@ function Makie.convert_arguments(PT::Type{<:Plot},

# Extract indices (this is SOP)

idxs = idxs === nothing ? (1:length(sol.u[1])) : idxs
idxs = idxs === nothing ? plottable_indices(sol.u[1]) : idxs

# Check for analytic solution
if plot_analytic && (sol.u_analytic === nothing)
throw(ArgumentError("No analytic solution was found but `plot_analytic` was set to `true`."))
end

if !(idxs isa Union{Tuple, AbstractArray})
vars = SciMLBase.interpret_vars([idxs], sol)
else
vars = SciMLBase.interpret_vars(idxs, sol)
end

# Separate continuous and discrete variables
disc_vars = Tuple[]
cont_vars = Tuple[]
for var in vars
tsidxs = union(SciMLBase.get_all_timeseries_indexes(sol, var[2]),
SciMLBase.get_all_timeseries_indexes(sol, var[3]))
if SciMLBase.ContinuousTimeseries() in tsidxs || isempty(tsidxs)
push!(cont_vars, var)
else
push!(disc_vars, (var..., only(tsidxs)))
end
end

# Translate automatics inside the function, for ease of use + passthrough from higher
# level recipes
if denseplot isa Makie.Automatic
denseplot = (sol.dense ||
typeof(sol.prob) <: SciMLBase.AbstractDiscreteProblem) &&
!(typeof(sol) <: SciMLBase.AbstractRODESolution) &&
!(hasfield(typeof(sol), :interp) &&
typeof(sol.interp) <: SciMLBase.SensitivityInterpolation)
denseplot = SciMLBase.isdenseplot(sol)
end

if plotdensity isa Makie.Automatic
Expand All @@ -107,7 +125,7 @@ function Makie.convert_arguments(PT::Type{<:Plot},
# Convert the solution to arrays - this is the hard part!
plot_vecs,
labels = SciMLBase.diffeq_to_arrays(sol, plot_analytic, denseplot,
plotdensity, tspan, vars, tscale, plotat)
plotdensity, tspan, cont_vars, tscale, plotat)

# We must convert from plot Type to symbol here, for plotspec use
# since PlotSpecs are defined based on symbols
Expand All @@ -116,17 +134,89 @@ function Makie.convert_arguments(PT::Type{<:Plot},
# Finally, generate a vector of PlotSpecs (one per variable pair)
# TODO: broadcast across all input attributes, or figure out how to
# allow customizable colors/labels/etc if required
makie_plotspecs = if length(plot_vecs) == 2
map((x, y, label) -> PlotSpec(plot_type_sym, Point2f.(x, y); label),
eachcol(plot_vecs[1]),
eachcol(plot_vecs[2]),
labels)
makie_plotspecs = if isempty(cont_vars) || (isempty(labels) && isempty(plot_vecs))
# No continuous variables, start with empty vector
PlotSpec[]
elseif length(plot_vecs) == 0
PlotSpec[]
elseif length(plot_vecs) == 2
# Count how many are numerical vs analytical solutions
n_plots = size(plot_vecs[1], 2)
n_numeric = plot_analytic ? n_plots ÷ 2 : n_plots

plots = PlotSpec[]
for i in 1:n_plots
# Use Cycled color for numeric solutions, different style for analytic
if plot_analytic && i > n_numeric
# Analytic solution - use dashed line
push!(plots, PlotSpec(plot_type_sym, Point2f.(plot_vecs[1][:, i], plot_vecs[2][:, i]);
label=labels[i], linestyle=:dash, color=Makie.Cycled(i - n_numeric)))
else
push!(plots, PlotSpec(plot_type_sym, Point2f.(plot_vecs[1][:, i], plot_vecs[2][:, i]);
label=labels[i], color=Makie.Cycled(i)))
end
end
plots
elseif length(plot_vecs) == 3
map((x, y, z, label) -> PlotSpec(plot_type_sym, Point3f.(x, y, z); label),
eachcol(plot_vecs[1]),
eachcol(plot_vecs[2]),
eachcol(plot_vecs[3]),
labels)
n_plots = size(plot_vecs[1], 2)
n_numeric = plot_analytic ? n_plots ÷ 2 : n_plots

plots = PlotSpec[]
for i in 1:n_plots
if plot_analytic && i > n_numeric
push!(plots, PlotSpec(plot_type_sym, Point3f.(plot_vecs[1][:, i], plot_vecs[2][:, i], plot_vecs[3][:, i]);
label=labels[i], linestyle=:dash, color=Makie.Cycled(i - n_numeric)))
else
push!(plots, PlotSpec(plot_type_sym, Point3f.(plot_vecs[1][:, i], plot_vecs[2][:, i], plot_vecs[3][:, i]);
label=labels[i], color=Makie.Cycled(i)))
end
end
plots
else
PlotSpec[]
end

# Add discrete variable plots
if hasfield(typeof(sol), :discretes) && !isempty(disc_vars)
for (func, xvar, yvar, tsidx) in disc_vars
partition = sol.discretes[tsidx]
ts = SciMLBase.current_time(partition)

# Apply tspan filtering
if tspan !== nothing
tstart = searchsortedfirst(ts, tspan[1])
tend = searchsortedlast(ts, tspan[2])
if tstart == lastindex(ts) + 1 || tend == firstindex(ts) - 1
continue
end
else
tstart = firstindex(ts)
tend = lastindex(ts)
end
ts = ts[tstart:tend]

# Get values
if SciMLBase.symbolic_type(xvar) == SciMLBase.NotSymbolic() && xvar == 0
xvar = only(SciMLBase.independent_variable_symbols(sol))
end
xvals = sol(ts; idxs = xvar).u
yvals = SciMLBase.getp(sol, yvar)(sol, tstart:tend)
tmpvals = map(func, xvals, yvals)
xvals = getindex.(tmpvals, 1)
yvals = getindex.(tmpvals, 2)

# Create stepped line visualization
x = vec([xvals[1:(end - 1)]'; xvals[2:end]'])
y = repeat(yvals, inner = 2)[1:(end - 1)]

push!(makie_plotspecs, S.Lines(Point2f.(x, y);
linestyle = :dash,
marker = :circle,
markersize = repeat([4, 0], length(ts) - 1),
markeralpha = repeat([1, 0], length(ts) - 1),
label = string(SciMLBase.hasname(yvar) ? SciMLBase.getname(yvar) : yvar)
))
end
end

return makie_plotspecs
Expand Down
Loading