diff --git a/ext/SciMLBaseMakieExt.jl b/ext/SciMLBaseMakieExt.jl index 0071ee47c..e67bc269a 100644 --- a/ext/SciMLBaseMakieExt.jl +++ b/ext/SciMLBaseMakieExt.jl @@ -57,6 +57,10 @@ function Makie.convert_arguments(PT::Type{<:Plot}, # Makie error message for convert_arguments - just at a different place. # TODO: this is a bit of a hack, but of course one can define specific dispatches elsewhere... ensure_plottrait(PT, sol, Makie.PointBased) + + # Helper function to determine plottable indices + plottable_indices(x::AbstractArray) = 1:length(x) + plottable_indices(x::Number) = 1 if vars !== nothing Base.depwarn( @@ -69,22 +73,36 @@ function Makie.convert_arguments(PT::Type{<:Plot}, # Extract indices (this is SOP) - idxs = idxs === nothing ? (1:length(sol.u[1])) : idxs + idxs = idxs === nothing ? plottable_indices(sol.u[1]) : idxs + + # Check for analytic solution + if plot_analytic && (sol.u_analytic === nothing) + throw(ArgumentError("No analytic solution was found but `plot_analytic` was set to `true`.")) + end if !(idxs isa Union{Tuple, AbstractArray}) vars = SciMLBase.interpret_vars([idxs], sol) else vars = SciMLBase.interpret_vars(idxs, sol) end + + # Separate continuous and discrete variables + disc_vars = Tuple[] + cont_vars = Tuple[] + for var in vars + tsidxs = union(SciMLBase.get_all_timeseries_indexes(sol, var[2]), + SciMLBase.get_all_timeseries_indexes(sol, var[3])) + if SciMLBase.ContinuousTimeseries() in tsidxs || isempty(tsidxs) + push!(cont_vars, var) + else + push!(disc_vars, (var..., only(tsidxs))) + end + end # Translate automatics inside the function, for ease of use + passthrough from higher # level recipes if denseplot isa Makie.Automatic - denseplot = (sol.dense || - typeof(sol.prob) <: SciMLBase.AbstractDiscreteProblem) && - !(typeof(sol) <: SciMLBase.AbstractRODESolution) && - !(hasfield(typeof(sol), :interp) && - typeof(sol.interp) <: SciMLBase.SensitivityInterpolation) + denseplot = SciMLBase.isdenseplot(sol) end if plotdensity isa Makie.Automatic @@ -107,7 +125,7 @@ function Makie.convert_arguments(PT::Type{<:Plot}, # Convert the solution to arrays - this is the hard part! plot_vecs, labels = SciMLBase.diffeq_to_arrays(sol, plot_analytic, denseplot, - plotdensity, tspan, vars, tscale, plotat) + plotdensity, tspan, cont_vars, tscale, plotat) # We must convert from plot Type to symbol here, for plotspec use # since PlotSpecs are defined based on symbols @@ -116,17 +134,89 @@ function Makie.convert_arguments(PT::Type{<:Plot}, # Finally, generate a vector of PlotSpecs (one per variable pair) # TODO: broadcast across all input attributes, or figure out how to # allow customizable colors/labels/etc if required - makie_plotspecs = if length(plot_vecs) == 2 - map((x, y, label) -> PlotSpec(plot_type_sym, Point2f.(x, y); label), - eachcol(plot_vecs[1]), - eachcol(plot_vecs[2]), - labels) + makie_plotspecs = if isempty(cont_vars) || (isempty(labels) && isempty(plot_vecs)) + # No continuous variables, start with empty vector + PlotSpec[] + elseif length(plot_vecs) == 0 + PlotSpec[] + elseif length(plot_vecs) == 2 + # Count how many are numerical vs analytical solutions + n_plots = size(plot_vecs[1], 2) + n_numeric = plot_analytic ? n_plots ÷ 2 : n_plots + + plots = PlotSpec[] + for i in 1:n_plots + # Use Cycled color for numeric solutions, different style for analytic + if plot_analytic && i > n_numeric + # Analytic solution - use dashed line + push!(plots, PlotSpec(plot_type_sym, Point2f.(plot_vecs[1][:, i], plot_vecs[2][:, i]); + label=labels[i], linestyle=:dash, color=Makie.Cycled(i - n_numeric))) + else + push!(plots, PlotSpec(plot_type_sym, Point2f.(plot_vecs[1][:, i], plot_vecs[2][:, i]); + label=labels[i], color=Makie.Cycled(i))) + end + end + plots elseif length(plot_vecs) == 3 - map((x, y, z, label) -> PlotSpec(plot_type_sym, Point3f.(x, y, z); label), - eachcol(plot_vecs[1]), - eachcol(plot_vecs[2]), - eachcol(plot_vecs[3]), - labels) + n_plots = size(plot_vecs[1], 2) + n_numeric = plot_analytic ? n_plots ÷ 2 : n_plots + + plots = PlotSpec[] + for i in 1:n_plots + if plot_analytic && i > n_numeric + push!(plots, PlotSpec(plot_type_sym, Point3f.(plot_vecs[1][:, i], plot_vecs[2][:, i], plot_vecs[3][:, i]); + label=labels[i], linestyle=:dash, color=Makie.Cycled(i - n_numeric))) + else + push!(plots, PlotSpec(plot_type_sym, Point3f.(plot_vecs[1][:, i], plot_vecs[2][:, i], plot_vecs[3][:, i]); + label=labels[i], color=Makie.Cycled(i))) + end + end + plots + else + PlotSpec[] + end + + # Add discrete variable plots + if hasfield(typeof(sol), :discretes) && !isempty(disc_vars) + for (func, xvar, yvar, tsidx) in disc_vars + partition = sol.discretes[tsidx] + ts = SciMLBase.current_time(partition) + + # Apply tspan filtering + if tspan !== nothing + tstart = searchsortedfirst(ts, tspan[1]) + tend = searchsortedlast(ts, tspan[2]) + if tstart == lastindex(ts) + 1 || tend == firstindex(ts) - 1 + continue + end + else + tstart = firstindex(ts) + tend = lastindex(ts) + end + ts = ts[tstart:tend] + + # Get values + if SciMLBase.symbolic_type(xvar) == SciMLBase.NotSymbolic() && xvar == 0 + xvar = only(SciMLBase.independent_variable_symbols(sol)) + end + xvals = sol(ts; idxs = xvar).u + yvals = SciMLBase.getp(sol, yvar)(sol, tstart:tend) + tmpvals = map(func, xvals, yvals) + xvals = getindex.(tmpvals, 1) + yvals = getindex.(tmpvals, 2) + + # Create stepped line visualization + x = vec([xvals[1:(end - 1)]'; xvals[2:end]']) + y = repeat(yvals, inner = 2)[1:(end - 1)] + + push!(makie_plotspecs, S.Lines(Point2f.(x, y); + linestyle = :dash, + marker = :circle, + markersize = repeat([4, 0], length(ts) - 1), + markeralpha = repeat([1, 0], length(ts) - 1), + label = string(SciMLBase.hasname(yvar) ? SciMLBase.getname(yvar) : yvar) + )) + end end return makie_plotspecs