Skip to content

Commit f769206

Browse files
feat: add discrete variables support to solution plot recipe
1 parent 20efc75 commit f769206

File tree

1 file changed

+112
-65
lines changed

1 file changed

+112
-65
lines changed

src/solutions/solution_interface.jl

Lines changed: 112 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -176,91 +176,138 @@ DEFAULT_PLOT_FUNC(x, y, z) = (x, y, z) # For v0.5.2 bug
176176
end
177177

178178
idxs = idxs === nothing ? (1:length(sol.u[1])) : idxs
179-
179+
disc_idxs = []
180+
cont_idxs = []
181+
for idx in idxs
182+
tsidxs = get_all_timeseries_indexes(sol, idx)
183+
if ContinuousTimeseries() in tsidxs
184+
push!(cont_idxs, idx)
185+
else
186+
push!(disc_idxs, (idx, only(tsidxs)))
187+
end
188+
end
189+
idxs = identity.(cont_idxs)
180190
if !(idxs isa Union{Tuple, AbstractArray})
181191
vars = interpret_vars([idxs], sol)
182192
else
183193
vars = interpret_vars(idxs, sol)
184194
end
185-
186-
tscale = get(plotattributes, :xscale, :identity)
187-
plot_vecs, labels = diffeq_to_arrays(sol, plot_analytic, denseplot,
188-
plotdensity, tspan, vars, tscale, plotat)
189-
190195
tdir = sign(sol.t[end] - sol.t[1])
191196
xflip --> tdir < 0
192197
seriestype --> :path
193198

194-
# Special case labels when idxs = (:x,:y,:z) or (:x) or [:x,:y] ...
195-
if idxs isa Tuple && vars[1][1] === DEFAULT_PLOT_FUNC
196-
val = hasname(vars[1][2]) ? String(getname(vars[1][2])) : vars[1][2]
197-
if val isa Integer
198-
if val == 0
199-
val = "t"
200-
else
201-
val = "u[$val]"
199+
@series begin
200+
if isempty(idxs)
201+
label --> nothing
202+
([], [])
203+
else
204+
tscale = get(plotattributes, :xscale, :identity)
205+
plot_vecs, labels = diffeq_to_arrays(sol, plot_analytic, denseplot,
206+
plotdensity, tspan, vars, tscale, plotat)
207+
208+
209+
# Special case labels when idxs = (:x,:y,:z) or (:x) or [:x,:y] ...
210+
if idxs isa Tuple && vars[1][1] === DEFAULT_PLOT_FUNC
211+
val = hasname(vars[1][2]) ? String(getname(vars[1][2])) : vars[1][2]
212+
if val isa Integer
213+
if val == 0
214+
val = "t"
215+
else
216+
val = "u[$val]"
217+
end
218+
end
219+
xguide --> val
220+
val = hasname(vars[1][3]) ? String(getname(vars[1][3])) : vars[1][3]
221+
if val isa Integer
222+
if val == 0
223+
val = "t"
224+
else
225+
val = "u[$val]"
226+
end
227+
end
228+
yguide --> val
229+
if length(idxs) > 2
230+
val = hasname(vars[1][4]) ? String(getname(vars[1][4])) : vars[1][4]
231+
if val isa Integer
232+
if val == 0
233+
val = "t"
234+
else
235+
val = "u[$val]"
236+
end
237+
end
238+
zguide --> val
239+
end
202240
end
203-
end
204-
xguide --> val
205-
val = hasname(vars[1][3]) ? String(getname(vars[1][3])) : vars[1][3]
206-
if val isa Integer
207-
if val == 0
208-
val = "t"
209-
else
210-
val = "u[$val]"
241+
242+
if (!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 1))) &&
243+
getindex.(vars, 1) == zeros(length(vars))) ||
244+
(!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 2))) &&
245+
getindex.(vars, 2) == zeros(length(vars))) ||
246+
all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 1)) ||
247+
all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 2))
248+
xguide --> "$(getindepsym_defaultt(sol))"
211249
end
212-
end
213-
yguide --> val
214-
if length(idxs) > 2
215-
val = hasname(vars[1][4]) ? String(getname(vars[1][4])) : vars[1][4]
216-
if val isa Integer
217-
if val == 0
218-
val = "t"
250+
if length(vars[1]) >= 3 &&
251+
((!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 3))) &&
252+
getindex.(vars, 3) == zeros(length(vars))) ||
253+
all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 3)))
254+
yguide --> "$(getindepsym_defaultt(sol))"
255+
end
256+
if length(vars[1]) >= 4 &&
257+
((!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 4))) &&
258+
getindex.(vars, 4) == zeros(length(vars))) ||
259+
all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 4)))
260+
zguide --> "$(getindepsym_defaultt(sol))"
261+
end
262+
263+
if (!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 2))) &&
264+
getindex.(vars, 2) == zeros(length(vars))) ||
265+
all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 2))
266+
if tspan === nothing
267+
if tdir > 0
268+
xlims --> (sol.t[1], sol.t[end])
269+
else
270+
xlims --> (sol.t[end], sol.t[1])
271+
end
219272
else
220-
val = "u[$val]"
273+
xlims --> (tspan[1], tspan[end])
221274
end
222275
end
223-
zguide --> val
224-
end
225-
end
226276

227-
if (!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 1))) &&
228-
getindex.(vars, 1) == zeros(length(vars))) ||
229-
(!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 2))) &&
230-
getindex.(vars, 2) == zeros(length(vars))) ||
231-
all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 1)) ||
232-
all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 2))
233-
xguide --> "$(getindepsym_defaultt(sol))"
234-
end
235-
if length(vars[1]) >= 3 &&
236-
((!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 3))) &&
237-
getindex.(vars, 3) == zeros(length(vars))) ||
238-
all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 3)))
239-
yguide --> "$(getindepsym_defaultt(sol))"
240-
end
241-
if length(vars[1]) >= 4 &&
242-
((!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 4))) &&
243-
getindex.(vars, 4) == zeros(length(vars))) ||
244-
all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 4)))
245-
zguide --> "$(getindepsym_defaultt(sol))"
277+
label --> reshape(labels, 1, length(labels))
278+
(plot_vecs...,)
279+
end
246280
end
247-
248-
if (!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 2))) &&
249-
getindex.(vars, 2) == zeros(length(vars))) ||
250-
all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 2))
251-
if tspan === nothing
252-
if tdir > 0
253-
xlims --> (sol.t[1], sol.t[end])
254-
else
255-
xlims --> (sol.t[end], sol.t[1])
281+
for (idx, tsidx) in disc_idxs
282+
partition = sol.discretes[tsidx]
283+
ts = current_time(partition)
284+
if tspan !== nothing
285+
tstart = searchsortedfirst(ts, tspan[1])
286+
tend = searchsortedlast(ts, tspan[2])
287+
if tstart == lastindex(ts) + 1 || tend == firstindex(ts) - 1
288+
continue
256289
end
257290
else
258-
xlims --> (tspan[1], tspan[end])
291+
tstart = firstindex(ts)
292+
tend = lastindex(ts)
293+
end
294+
ts = ts[tstart:tend]
295+
296+
vals = getp(sol, idx)(sol, tstart:tend)
297+
# Scatterplot of points
298+
@series begin
299+
seriestype := :line
300+
linestyle --> :dash
301+
markershape --> :o
302+
markersize --> repeat([2, 0], length(ts)-1)
303+
markeralpha --> repeat([1, 0], length(ts)-1)
304+
label --> string(hasname(idx) ? getname(idx) : idx)
305+
306+
x = vec([ts[1:end-1]'; ts[2:end]'])
307+
y = repeat(vals, inner=2)[1:end-1]
308+
x, y
259309
end
260310
end
261-
262-
label --> reshape(labels, 1, length(labels))
263-
(plot_vecs...,)
264311
end
265312

266313
function diffeq_to_arrays(sol, plot_analytic, denseplot, plotdensity, tspan,

0 commit comments

Comments
 (0)