Skip to content

Commit b5d8407

Browse files
Much better support for labelledarrays
Automatic plot labelled, support with DAE solvers (no allocs!), etc. They should just work everywhere.
1 parent 6555bf3 commit b5d8407

File tree

6 files changed

+64
-24
lines changed

6 files changed

+64
-24
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1313
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
1414
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
1515
IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d"
16+
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
1617
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1718
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
1819
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"

src/DiffEqBase.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using RecipesBase, RecursiveArrayTools,
77

88
import Logging, LoggingExtras, TerminalLoggers, ConsoleProgressMonitor, ProgressLogging
99

10-
import ZygoteRules, ChainRulesCore
10+
import ZygoteRules, ChainRulesCore, LabelledArrays
1111

1212
using Roots # callbacks
1313

src/init.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ value(x) = x
22
cuify(x) = error("To use LinSolveGPUFactorize, you must do `using CuArrays`")
33
promote_u0(u0,p,t0) = u0
44
promote_tspan(u0,p,tspan,prob,kwargs) = tspan
5+
get_tmp(x) = nothing
56

67
if VERSION < v"1.4.0-DEV.635"
78
# Piracy, should get upstreamed
@@ -61,12 +62,21 @@ function __init__()
6162
end
6263

6364
function DiffCache(u::AbstractArray{T}, siz, ::Type{Val{chunk_size}}) where {T, chunk_size}
64-
DiffCache(u, zeros(ForwardDiff.Dual{nothing,T,chunk_size}, siz...))
65+
x = ArrayInterface.restructure(u,zeros(ForwardDiff.Dual{nothing,T,chunk_size}, siz...))
66+
DiffCache(u, x)
6567
end
6668

6769
dualcache(u::AbstractArray, N=Val{ForwardDiff.pickchunksize(length(u))}) = DiffCache(u, size(u), N)
6870

69-
get_tmp(dc::DiffCache, u::AbstractArray{T}) where T<:ForwardDiff.Dual = reinterpret(T, dc.dual_du)
71+
function get_tmp(dc::DiffCache, u::AbstractArray{T}) where T<:ForwardDiff.Dual
72+
x = reinterpret(T, dc.dual_du)
73+
end
74+
75+
function DiffEqBase.get_tmp(dc::DiffEqBase.DiffCache, u::LabelledArrays.LArray{T,N,D,Syms}) where {T,N,D,Syms}
76+
x = reinterpret(T, dc.dual_du.__x)
77+
LArray{T,N,D,Syms}(x)
78+
end
79+
7080
get_tmp(dc::DiffCache, u::AbstractArray) = dc.du
7181

7282
bisection(f, tup::Tuple{T,T}, tdir) where {T<:ForwardDiff.Dual} = find_zero(f, tup, Roots.AlefeldPotraShi())

src/solutions/solution_interface.jl

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,12 @@ DEFAULT_PLOT_FUNC(x,y,z) = (x,y,z) # For v0.5.2 bug
6969
tspan = nothing, axis_safety = 0.1,
7070
vars=nothing)
7171

72-
int_vars = interpret_vars(vars,sol)
72+
syms = getsyms(sol)
73+
int_vars = interpret_vars(vars,sol,syms)
7374
tscale = get(plotattributes, :xscale, :identity)
7475
plot_vecs,labels = diffeq_to_arrays(sol,plot_analytic,denseplot,
7576
plotdensity,tspan,axis_safety,
76-
vars,int_vars,tscale)
77+
vars,int_vars,tscale,syms)
7778

7879
tdir = sign(sol.t[end]-sol.t[1])
7980
xflip --> tdir < 0
@@ -146,7 +147,17 @@ DEFAULT_PLOT_FUNC(x,y,z) = (x,y,z) # For v0.5.2 bug
146147
(plot_vecs...,)
147148
end
148149

149-
function diffeq_to_arrays(sol,plot_analytic,denseplot,plotdensity,tspan,axis_safety,vars,int_vars,tscale)
150+
function getsyms(sol)
151+
if has_syms(sol.prob.f)
152+
return sol.prob.f.syms
153+
elseif typeof(sol.u[1]) <: Union{LabelledArrays.LArray,LabelledArrays.SLArray}
154+
return LabelledArrays.symnames(typeof(sol.u[1]))
155+
else
156+
return nothing
157+
end
158+
end
159+
160+
function diffeq_to_arrays(sol,plot_analytic,denseplot,plotdensity,tspan,axis_safety,vars,int_vars,tscale,syms)
150161
if tspan === nothing
151162
if sol.tslocation == 0
152163
end_idx = length(sol)
@@ -217,21 +228,21 @@ function diffeq_to_arrays(sol,plot_analytic,denseplot,plotdensity,tspan,axis_saf
217228
@assert length(var) == dims
218229
end
219230
# Should check that all have the same dims!
220-
plot_vecs,labels = solplot_vecs_and_labels(dims,int_vars,plot_timeseries,plott,sol,plot_analytic,plot_analytic_timeseries)
231+
plot_vecs,labels = solplot_vecs_and_labels(dims,int_vars,plot_timeseries,plott,sol,plot_analytic,plot_analytic_timeseries,syms)
221232
end
222233

223-
function interpret_vars(vars,sol)
224-
if vars !== nothing && has_syms(sol.prob.f)
234+
function interpret_vars(vars,sol,syms)
235+
if vars !== nothing && syms !== nothing
225236
# Do syms conversion
226237
tmp_vars = []
227238
for var in vars
228239
if typeof(var) <: Symbol
229-
var_int = something(findfirst(isequal(var), sol.prob.f.syms), 0)
240+
var_int = something(findfirst(isequal(var), syms), 0)
230241
elseif typeof(var) <: Union{Tuple,AbstractArray} #eltype(var) <: Symbol # Some kind of iterable
231242
tmp = []
232243
for x in var
233244
if typeof(x) <: Symbol
234-
push!(tmp,something(findfirst(isequal(x), sol.prob.f.syms), 0))
245+
push!(tmp,something(findfirst(isequal(x), syms), 0))
235246
else
236247
push!(tmp,x)
237248
end
@@ -314,14 +325,14 @@ function interpret_vars(vars,sol)
314325
vars
315326
end
316327

317-
function add_labels!(labels,x,dims,sol)
328+
function add_labels!(labels,x,dims,sol,syms)
318329
lys = []
319330
for j in 3:dims
320331
if x[j] == 0
321332
push!(lys,"t,")
322333
else
323-
if has_syms(sol.prob.f)
324-
push!(lys,"$(sol.prob.f.syms[x[j]]),")
334+
if syms !== nothing
335+
push!(lys,"$(syms[x[j]]),")
325336
else
326337
push!(lys,"u$(x[j]),")
327338
end
@@ -331,8 +342,8 @@ function add_labels!(labels,x,dims,sol)
331342
if x[2] == 0 && dims == 3
332343
tmp_lab = "$(lys...)(t)"
333344
else
334-
if has_syms(sol.prob.f) && x[2] != 0
335-
tmp = sol.prob.f.syms[x[2]]
345+
if syms !== nothing && x[2] != 0
346+
tmp = syms[x[2]]
336347
tmp_lab = "($tmp,$(lys...))"
337348
else
338349
if x[2] == 0
@@ -350,14 +361,14 @@ function add_labels!(labels,x,dims,sol)
350361
labels
351362
end
352363

353-
function add_analytic_labels!(labels,x,dims,sol)
364+
function add_analytic_labels!(labels,x,dims,sol,syms)
354365
lys = []
355366
for j in 3:dims
356367
if x[j] == 0 && dims == 3
357368
push!(lys,"t,")
358369
else
359-
if has_syms(sol.prob.f)
360-
push!(lys,string("True ",sol.prob.f.syms[x[j]],","))
370+
if syms !== nothing
371+
push!(lys,string("True ",syms[x[j]],","))
361372
else
362373
push!(lys,"True u$(x[j]),")
363374
end
@@ -367,8 +378,8 @@ function add_analytic_labels!(labels,x,dims,sol)
367378
if x[2] == 0
368379
tmp_lab = "$(lys...)(t)"
369380
else
370-
if has_syms(sol.prob.f)
371-
tmp = string("True ",sol.prob.f.syms[x[2]])
381+
if syms !== nothing
382+
tmp = string("True ",syms[x[2]])
372383
tmp_lab = "($tmp,$(lys...))"
373384
else
374385
tmp_lab = "(True u$(x[2]),$(lys...))"
@@ -396,7 +407,7 @@ function u_n(timeseries::AbstractArray, n::Int,sol,plott,plot_timeseries)
396407
end
397408
end
398409

399-
function solplot_vecs_and_labels(dims,vars,plot_timeseries,plott,sol,plot_analytic,plot_analytic_timeseries)
410+
function solplot_vecs_and_labels(dims,vars,plot_timeseries,plott,sol,plot_analytic,plot_analytic_timeseries,syms)
400411
plot_vecs = []
401412
labels = String[]
402413
for x in vars
@@ -416,7 +427,7 @@ function solplot_vecs_and_labels(dims,vars,plot_timeseries,plott,sol,plot_analyt
416427
end
417428
push!(plot_vecs[i],tmp[i])
418429
end
419-
add_labels!(labels,x,dims,sol)
430+
add_labels!(labels,x,dims,sol,syms)
420431
end
421432

422433

@@ -434,7 +445,7 @@ function solplot_vecs_and_labels(dims,vars,plot_timeseries,plott,sol,plot_analyt
434445
for i in eachindex(tmp)
435446
push!(plot_vecs[i],tmp[i])
436447
end
437-
add_analytic_labels!(labels,x,dims,sol)
448+
add_analytic_labels!(labels,x,dims,sol,syms)
438449
end
439450
end
440451
plot_vecs = [hcat(x...) for x in plot_vecs]

test/downstream/labelledarrays.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
using OrdinaryDiffEq
2+
using LabelledArrays
3+
4+
function f(out,du,u,p,t)
5+
out.x = - 0.04u.x + 1e4*u.y*u.z - du.x
6+
out.y = + 0.04u.x - 3e7*u.y^2 - 1e4*u.y*u.z - du.y
7+
out.z = u.x + u.y + u.z - 1.0
8+
end
9+
10+
u₀ = LVector(x=1.0, y=0.0, z=0.0)
11+
du₀ = LVector(x=-0.04, y=0.04, z=0.0)
12+
tspan = (0.0,100000.0)
13+
14+
differential_vars = LVector(x=true, y=true, z=false)
15+
prob = DAEProblem(f,du₀,u₀,tspan,differential_vars=differential_vars)
16+
17+
sol = solve(prob, DImplicitEuler())

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ if !is_APPVEYOR && GROUP == "Downstream"
4040
@time @safetestset "Ensemble Analysis" begin include("downstream/ensemble_analysis.jl") end
4141
@time @safetestset "Default linsolve with structure" begin include("downstream/default_linsolve_structure.jl") end
4242
@time @safetestset "Callback Merging Tests" begin include("downstream/callback_merging.jl") end
43+
@time @safetestset "LabelledArrays Tests" begin include("downstream/labelledarrays.jl") end
4344
@time @safetestset "ODE Event Tests" begin include("downstream/ode_event_tests.jl") end
4445
@time @safetestset "Event Detection Tests" begin include("downstream/event_detection_tests.jl") end
4546
@time @safetestset "PSOS and Energy Conservation Event Detection" begin include("downstream/psos_and_energy_conservation.jl") end

0 commit comments

Comments
 (0)