Skip to content

Commit 8094a37

Browse files
Improve plotting for ModelingToolkit changes
- automatically convert \_+ to non-Unicode to make it easier for backends to support module syntax - check for integers first, then convert anything else to a symbol. This makes the MTK variable automatically supported.
1 parent 3b36d39 commit 8094a37

File tree

1 file changed

+35
-26
lines changed

1 file changed

+35
-26
lines changed

src/solutions/solution_interface.jl

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -71,21 +71,23 @@ DEFAULT_PLOT_FUNC(x,y,z) = (x,y,z) # For v0.5.2 bug
7171

7272
syms = getsyms(sol)
7373
int_vars = interpret_vars(vars,sol,syms)
74+
strs = cleansyms(syms)
75+
7476
tscale = get(plotattributes, :xscale, :identity)
7577
plot_vecs,labels = diffeq_to_arrays(sol,plot_analytic,denseplot,
7678
plotdensity,tspan,axis_safety,
77-
vars,int_vars,tscale,syms)
79+
vars,int_vars,tscale,strs)
7880

7981
tdir = sign(sol.t[end]-sol.t[1])
8082
xflip --> tdir < 0
8183
seriestype --> :path
8284

8385
# Special case labels when vars = (:x,:y,:z) or (:x) or [:x,:y] ...
8486
if typeof(vars) <: Tuple && (typeof(vars[1]) == Symbol && typeof(vars[2]) == Symbol)
85-
xguide --> vars[1]
86-
yguide --> vars[2]
87+
xguide --> strs[int_vars[1][2]]
88+
yguide --> strs[int_vars[1][3]]
8789
if length(vars) > 2
88-
zguide --> vars[3]
90+
zguide --> strs[int_vars[1][4]]
8991
end
9092
end
9193
if getindex.(int_vars,1) == zeros(length(int_vars)) || getindex.(int_vars,2) == zeros(length(int_vars))
@@ -157,7 +159,14 @@ function getsyms(sol)
157159
end
158160
end
159161

160-
function diffeq_to_arrays(sol,plot_analytic,denseplot,plotdensity,tspan,axis_safety,vars,int_vars,tscale,syms)
162+
cleansyms(syms::Nothing) = nothing
163+
cleansyms(syms::Vector{Symbol}) = cleansym.(syms)
164+
function cleansym(sym::Symbol)
165+
str = String(sym)
166+
replace(str,""=>".") # Fix MTK component syntax
167+
end
168+
169+
function diffeq_to_arrays(sol,plot_analytic,denseplot,plotdensity,tspan,axis_safety,vars,int_vars,tscale,strs)
161170
if tspan === nothing
162171
if sol.tslocation == 0
163172
end_idx = length(sol)
@@ -228,32 +237,32 @@ function diffeq_to_arrays(sol,plot_analytic,denseplot,plotdensity,tspan,axis_saf
228237
@assert length(var) == dims
229238
end
230239
# Should check that all have the same dims!
231-
plot_vecs,labels = solplot_vecs_and_labels(dims,int_vars,plot_timeseries,plott,sol,plot_analytic,plot_analytic_timeseries,syms)
240+
plot_vecs,labels = solplot_vecs_and_labels(dims,int_vars,plot_timeseries,plott,sol,plot_analytic,plot_analytic_timeseries,strs)
232241
end
233242

234243
function interpret_vars(vars,sol,syms)
235244
if vars !== nothing && syms !== nothing
236245
# Do syms conversion
237246
tmp_vars = []
238247
for var in vars
239-
if typeof(var) <: Symbol
240-
var_int = something(findfirst(isequal(var), syms), 0)
241-
elseif typeof(var) <: Union{Tuple,AbstractArray} #eltype(var) <: Symbol # Some kind of iterable
248+
if typeof(var) <: Union{Tuple,AbstractArray} #eltype(var) <: Symbol # Some kind of iterable
242249
tmp = []
243250
for x in var
244-
if typeof(x) <: Symbol
245-
push!(tmp,something(findfirst(isequal(x), syms), 0))
246-
else
251+
if typeof(x) <: Integer
247252
push!(tmp,x)
253+
else
254+
push!(tmp,something(findfirst(isequal(Symbol(x)), syms), 0))
248255
end
249256
end
250257
if typeof(var) <: Tuple
251258
var_int = tuple(tmp...)
252259
else
253260
var_int = tmp
254261
end
255-
else
262+
elseif typeof(var) <: Integer
256263
var_int = var
264+
else
265+
var_int = something(findfirst(isequal(Symbol(var)), syms), 0)
257266
end
258267
push!(tmp_vars,var_int)
259268
end
@@ -325,14 +334,14 @@ function interpret_vars(vars,sol,syms)
325334
vars
326335
end
327336

328-
function add_labels!(labels,x,dims,sol,syms)
337+
function add_labels!(labels,x,dims,sol,strs)
329338
lys = []
330339
for j in 3:dims
331340
if x[j] == 0
332341
push!(lys,"t,")
333342
else
334-
if syms !== nothing
335-
push!(lys,"$(syms[x[j]]),")
343+
if strs !== nothing
344+
push!(lys,"$(strs[x[j]]),")
336345
else
337346
push!(lys,"u$(x[j]),")
338347
end
@@ -342,8 +351,8 @@ function add_labels!(labels,x,dims,sol,syms)
342351
if x[2] == 0 && dims == 3
343352
tmp_lab = "$(lys...)(t)"
344353
else
345-
if syms !== nothing && x[2] != 0
346-
tmp = syms[x[2]]
354+
if strs !== nothing && x[2] != 0
355+
tmp = strs[x[2]]
347356
tmp_lab = "($tmp,$(lys...))"
348357
else
349358
if x[2] == 0
@@ -361,14 +370,14 @@ function add_labels!(labels,x,dims,sol,syms)
361370
labels
362371
end
363372

364-
function add_analytic_labels!(labels,x,dims,sol,syms)
373+
function add_analytic_labels!(labels,x,dims,sol,strs)
365374
lys = []
366375
for j in 3:dims
367376
if x[j] == 0 && dims == 3
368377
push!(lys,"t,")
369378
else
370-
if syms !== nothing
371-
push!(lys,string("True ",syms[x[j]],","))
379+
if strs !== nothing
380+
push!(lys,string("True ",strs[x[j]],","))
372381
else
373382
push!(lys,"True u$(x[j]),")
374383
end
@@ -378,8 +387,8 @@ function add_analytic_labels!(labels,x,dims,sol,syms)
378387
if x[2] == 0
379388
tmp_lab = "$(lys...)(t)"
380389
else
381-
if syms !== nothing
382-
tmp = string("True ",syms[x[2]])
390+
if strs !== nothing
391+
tmp = string("True ",strs[x[2]])
383392
tmp_lab = "($tmp,$(lys...))"
384393
else
385394
tmp_lab = "(True u$(x[2]),$(lys...))"
@@ -407,7 +416,7 @@ function u_n(timeseries::AbstractArray, n::Int,sol,plott,plot_timeseries)
407416
end
408417
end
409418

410-
function solplot_vecs_and_labels(dims,vars,plot_timeseries,plott,sol,plot_analytic,plot_analytic_timeseries,syms)
419+
function solplot_vecs_and_labels(dims,vars,plot_timeseries,plott,sol,plot_analytic,plot_analytic_timeseries,strs)
411420
plot_vecs = []
412421
labels = String[]
413422
for x in vars
@@ -427,7 +436,7 @@ function solplot_vecs_and_labels(dims,vars,plot_timeseries,plott,sol,plot_analyt
427436
end
428437
push!(plot_vecs[i],tmp[i])
429438
end
430-
add_labels!(labels,x,dims,sol,syms)
439+
add_labels!(labels,x,dims,sol,strs)
431440
end
432441

433442

@@ -445,7 +454,7 @@ function solplot_vecs_and_labels(dims,vars,plot_timeseries,plott,sol,plot_analyt
445454
for i in eachindex(tmp)
446455
push!(plot_vecs[i],tmp[i])
447456
end
448-
add_analytic_labels!(labels,x,dims,sol,syms)
457+
add_analytic_labels!(labels,x,dims,sol,strs)
449458
end
450459
end
451460
plot_vecs = [hcat(x...) for x in plot_vecs]

0 commit comments

Comments
 (0)