Skip to content

Commit ae94d75

Browse files
Merge pull request #475 from SciML/labelledarrays
Much better support for labelledarrays
2 parents 30aa4e3 + da3da9e commit ae94d75

File tree

8 files changed

+92
-43
lines changed

8 files changed

+92
-43
lines changed

.gitlab-ci.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,6 @@ variables:
77
stages:
88
- build
99
- test
10-
11-
job:
12-
cache: {}
1310

1411
build:
1512
stage: build

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1313
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
1414
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
1515
IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d"
16+
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
1617
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1718
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
1819
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"

src/DiffEqBase.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using RecipesBase, RecursiveArrayTools,
77

88
import Logging, LoggingExtras, TerminalLoggers, ConsoleProgressMonitor, ProgressLogging
99

10-
import ZygoteRules, ChainRulesCore
10+
import ZygoteRules, ChainRulesCore, LabelledArrays
1111

1212
using Roots # callbacks
1313

src/init.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ value(x) = x
22
cuify(x) = error("To use LinSolveGPUFactorize, you must do `using CuArrays`")
33
promote_u0(u0,p,t0) = u0
44
promote_tspan(u0,p,tspan,prob,kwargs) = tspan
5+
get_tmp(x) = nothing
56

67
if VERSION < v"1.4.0-DEV.635"
78
# Piracy, should get upstreamed
@@ -64,12 +65,21 @@ function __init__()
6465
end
6566

6667
function DiffCache(u::AbstractArray{T}, siz, ::Type{Val{chunk_size}}) where {T, chunk_size}
67-
DiffCache(u, zeros(ForwardDiff.Dual{nothing,T,chunk_size}, siz...))
68+
x = ArrayInterface.restructure(u,zeros(ForwardDiff.Dual{nothing,T,chunk_size}, siz...))
69+
DiffCache(u, x)
6870
end
6971

7072
dualcache(u::AbstractArray, N=Val{ForwardDiff.pickchunksize(length(u))}) = DiffCache(u, size(u), N)
7173

72-
get_tmp(dc::DiffCache, u::AbstractArray{T}) where T<:ForwardDiff.Dual = reinterpret(T, dc.dual_du)
74+
function get_tmp(dc::DiffCache, u::AbstractArray{T}) where T<:ForwardDiff.Dual
75+
x = reinterpret(T, dc.dual_du)
76+
end
77+
78+
function DiffEqBase.get_tmp(dc::DiffEqBase.DiffCache, u::LabelledArrays.LArray{T,N,D,Syms}) where {T,N,D,Syms}
79+
x = reinterpret(T, dc.dual_du.__x)
80+
LabelledArrays.LArray{T,N,D,Syms}(x)
81+
end
82+
7383
get_tmp(dc::DiffCache, u::AbstractArray) = dc.du
7484

7585
bisection(f, tup::Tuple{T,T}, tdir) where {T<:ForwardDiff.Dual} = find_zero(f, tup, Roots.AlefeldPotraShi())

src/solutions/solution_interface.jl

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,12 @@ DEFAULT_PLOT_FUNC(x,y,z) = (x,y,z) # For v0.5.2 bug
6969
tspan = nothing, axis_safety = 0.1,
7070
vars=nothing)
7171

72-
int_vars = interpret_vars(vars,sol)
72+
syms = getsyms(sol)
73+
int_vars = interpret_vars(vars,sol,syms)
7374
tscale = get(plotattributes, :xscale, :identity)
7475
plot_vecs,labels = diffeq_to_arrays(sol,plot_analytic,denseplot,
7576
plotdensity,tspan,axis_safety,
76-
vars,int_vars,tscale)
77+
vars,int_vars,tscale,syms)
7778

7879
tdir = sign(sol.t[end]-sol.t[1])
7980
xflip --> tdir < 0
@@ -146,7 +147,17 @@ DEFAULT_PLOT_FUNC(x,y,z) = (x,y,z) # For v0.5.2 bug
146147
(plot_vecs...,)
147148
end
148149

149-
function diffeq_to_arrays(sol,plot_analytic,denseplot,plotdensity,tspan,axis_safety,vars,int_vars,tscale)
150+
function getsyms(sol)
151+
if has_syms(sol.prob.f)
152+
return sol.prob.f.syms
153+
elseif typeof(sol.u[1]) <: Union{LabelledArrays.LArray,LabelledArrays.SLArray}
154+
return LabelledArrays.symnames(typeof(sol.u[1]))
155+
else
156+
return nothing
157+
end
158+
end
159+
160+
function diffeq_to_arrays(sol,plot_analytic,denseplot,plotdensity,tspan,axis_safety,vars,int_vars,tscale,syms)
150161
if tspan === nothing
151162
if sol.tslocation == 0
152163
end_idx = length(sol)
@@ -217,21 +228,21 @@ function diffeq_to_arrays(sol,plot_analytic,denseplot,plotdensity,tspan,axis_saf
217228
@assert length(var) == dims
218229
end
219230
# Should check that all have the same dims!
220-
plot_vecs,labels = solplot_vecs_and_labels(dims,int_vars,plot_timeseries,plott,sol,plot_analytic,plot_analytic_timeseries)
231+
plot_vecs,labels = solplot_vecs_and_labels(dims,int_vars,plot_timeseries,plott,sol,plot_analytic,plot_analytic_timeseries,syms)
221232
end
222233

223-
function interpret_vars(vars,sol)
224-
if vars !== nothing && has_syms(sol.prob.f)
234+
function interpret_vars(vars,sol,syms)
235+
if vars !== nothing && syms !== nothing
225236
# Do syms conversion
226237
tmp_vars = []
227238
for var in vars
228239
if typeof(var) <: Symbol
229-
var_int = something(findfirst(isequal(var), sol.prob.f.syms), 0)
240+
var_int = something(findfirst(isequal(var), syms), 0)
230241
elseif typeof(var) <: Union{Tuple,AbstractArray} #eltype(var) <: Symbol # Some kind of iterable
231242
tmp = []
232243
for x in var
233244
if typeof(x) <: Symbol
234-
push!(tmp,something(findfirst(isequal(x), sol.prob.f.syms), 0))
245+
push!(tmp,something(findfirst(isequal(x), syms), 0))
235246
else
236247
push!(tmp,x)
237248
end
@@ -314,14 +325,14 @@ function interpret_vars(vars,sol)
314325
vars
315326
end
316327

317-
function add_labels!(labels,x,dims,sol)
328+
function add_labels!(labels,x,dims,sol,syms)
318329
lys = []
319330
for j in 3:dims
320331
if x[j] == 0
321332
push!(lys,"t,")
322333
else
323-
if has_syms(sol.prob.f)
324-
push!(lys,"$(sol.prob.f.syms[x[j]]),")
334+
if syms !== nothing
335+
push!(lys,"$(syms[x[j]]),")
325336
else
326337
push!(lys,"u$(x[j]),")
327338
end
@@ -331,8 +342,8 @@ function add_labels!(labels,x,dims,sol)
331342
if x[2] == 0 && dims == 3
332343
tmp_lab = "$(lys...)(t)"
333344
else
334-
if has_syms(sol.prob.f) && x[2] != 0
335-
tmp = sol.prob.f.syms[x[2]]
345+
if syms !== nothing && x[2] != 0
346+
tmp = syms[x[2]]
336347
tmp_lab = "($tmp,$(lys...))"
337348
else
338349
if x[2] == 0
@@ -350,14 +361,14 @@ function add_labels!(labels,x,dims,sol)
350361
labels
351362
end
352363

353-
function add_analytic_labels!(labels,x,dims,sol)
364+
function add_analytic_labels!(labels,x,dims,sol,syms)
354365
lys = []
355366
for j in 3:dims
356367
if x[j] == 0 && dims == 3
357368
push!(lys,"t,")
358369
else
359-
if has_syms(sol.prob.f)
360-
push!(lys,string("True ",sol.prob.f.syms[x[j]],","))
370+
if syms !== nothing
371+
push!(lys,string("True ",syms[x[j]],","))
361372
else
362373
push!(lys,"True u$(x[j]),")
363374
end
@@ -367,8 +378,8 @@ function add_analytic_labels!(labels,x,dims,sol)
367378
if x[2] == 0
368379
tmp_lab = "$(lys...)(t)"
369380
else
370-
if has_syms(sol.prob.f)
371-
tmp = string("True ",sol.prob.f.syms[x[2]])
381+
if syms !== nothing
382+
tmp = string("True ",syms[x[2]])
372383
tmp_lab = "($tmp,$(lys...))"
373384
else
374385
tmp_lab = "(True u$(x[2]),$(lys...))"
@@ -396,7 +407,7 @@ function u_n(timeseries::AbstractArray, n::Int,sol,plott,plot_timeseries)
396407
end
397408
end
398409

399-
function solplot_vecs_and_labels(dims,vars,plot_timeseries,plott,sol,plot_analytic,plot_analytic_timeseries)
410+
function solplot_vecs_and_labels(dims,vars,plot_timeseries,plott,sol,plot_analytic,plot_analytic_timeseries,syms)
400411
plot_vecs = []
401412
labels = String[]
402413
for x in vars
@@ -416,7 +427,7 @@ function solplot_vecs_and_labels(dims,vars,plot_timeseries,plott,sol,plot_analyt
416427
end
417428
push!(plot_vecs[i],tmp[i])
418429
end
419-
add_labels!(labels,x,dims,sol)
430+
add_labels!(labels,x,dims,sol,syms)
420431
end
421432

422433

@@ -434,7 +445,7 @@ function solplot_vecs_and_labels(dims,vars,plot_timeseries,plott,sol,plot_analyt
434445
for i in eachindex(tmp)
435446
push!(plot_vecs[i],tmp[i])
436447
end
437-
add_analytic_labels!(labels,x,dims,sol)
448+
add_analytic_labels!(labels,x,dims,sol,syms)
438449
end
439450
end
440451
plot_vecs = [hcat(x...) for x in plot_vecs]

test/downstream/labelledarrays.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
using OrdinaryDiffEq
2+
using LabelledArrays
3+
4+
function f(out,du,u,p,t)
5+
out.x = - 0.04u.x + 1e4*u.y*u.z - du.x
6+
out.y = + 0.04u.x - 3e7*u.y^2 - 1e4*u.y*u.z - du.y
7+
out.z = u.x + u.y + u.z - 1.0
8+
end
9+
10+
u₀ = LVector(x=1.0, y=0.0, z=0.0)
11+
du₀ = LVector(x=-0.04, y=0.04, z=0.0)
12+
tspan = (0.0,100000.0)
13+
14+
differential_vars = LVector(x=true, y=true, z=false)
15+
prob = DAEProblem(f,du₀,u₀,tspan,differential_vars=differential_vars)
16+
17+
sol = solve(prob, DImplicitEuler())
18+
19+
function f1(du,u,p,t)
20+
du.x .= -1 .* u.x .* u.y .* p[1]
21+
du.y .= -1 .* u.y .* p[2]
22+
end
23+
const n = 1000
24+
u_0 = @LArray fill(1000.0,2*n) (x = (1:n),y = (n+1:2*n))
25+
p = [0.1,0.1]
26+
prob1 = ODEProblem(f1,u_0,(0,100.0),p)
27+
sol = solve(prob1, Rodas5())
28+
sol = solve(prob1, Rodas5(autodiff=false))

test/plot_vars.jl

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,24 +16,25 @@ tspan = (0., 100.)
1616
prob = ODEProblem(lorenz, u0, tspan)
1717
dt = 0.1
1818
sol = solve(prob,InternalEuler.FwdEulerAlg(), tstops=0:dt:1)
19+
syms = [:x,:y,:z]
1920

2021
@test DiffEqBase.has_syms(prob.f) == true
21-
@test DiffEqBase.interpret_vars([(0,1), (1,3), (4,5)],sol) == [(DiffEqBase.DEFAULT_PLOT_FUNC,0,1), (DiffEqBase.DEFAULT_PLOT_FUNC,1,3), (DiffEqBase.DEFAULT_PLOT_FUNC,4,5)]
22-
@test DiffEqBase.interpret_vars([1, (1,3), (4,5)],sol) == [(DiffEqBase.DEFAULT_PLOT_FUNC,0,1), (DiffEqBase.DEFAULT_PLOT_FUNC,1,3), (DiffEqBase.DEFAULT_PLOT_FUNC,4,5)]
23-
@test DiffEqBase.interpret_vars([1, 3, 4],sol) == [(DiffEqBase.DEFAULT_PLOT_FUNC,0,1), (DiffEqBase.DEFAULT_PLOT_FUNC,0,3), (DiffEqBase.DEFAULT_PLOT_FUNC,0,4)]
24-
@test DiffEqBase.interpret_vars(([1,2,3], [4,5,6]),sol) == [(DiffEqBase.DEFAULT_PLOT_FUNC,1,4), (DiffEqBase.DEFAULT_PLOT_FUNC,2,5), (DiffEqBase.DEFAULT_PLOT_FUNC,3,6)]
25-
@test DiffEqBase.interpret_vars((1, [2,3,4]),sol) == [(DiffEqBase.DEFAULT_PLOT_FUNC,1,2), (DiffEqBase.DEFAULT_PLOT_FUNC,1,3), (DiffEqBase.DEFAULT_PLOT_FUNC,1,4)]
22+
@test DiffEqBase.interpret_vars([(0,1), (1,3), (4,5)],sol,syms) == [(DiffEqBase.DEFAULT_PLOT_FUNC,0,1), (DiffEqBase.DEFAULT_PLOT_FUNC,1,3), (DiffEqBase.DEFAULT_PLOT_FUNC,4,5)]
23+
@test DiffEqBase.interpret_vars([1, (1,3), (4,5)],sol,syms) == [(DiffEqBase.DEFAULT_PLOT_FUNC,0,1), (DiffEqBase.DEFAULT_PLOT_FUNC,1,3), (DiffEqBase.DEFAULT_PLOT_FUNC,4,5)]
24+
@test DiffEqBase.interpret_vars([1, 3, 4],sol,syms) == [(DiffEqBase.DEFAULT_PLOT_FUNC,0,1), (DiffEqBase.DEFAULT_PLOT_FUNC,0,3), (DiffEqBase.DEFAULT_PLOT_FUNC,0,4)]
25+
@test DiffEqBase.interpret_vars(([1,2,3], [4,5,6]),sol,syms) == [(DiffEqBase.DEFAULT_PLOT_FUNC,1,4), (DiffEqBase.DEFAULT_PLOT_FUNC,2,5), (DiffEqBase.DEFAULT_PLOT_FUNC,3,6)]
26+
@test DiffEqBase.interpret_vars((1, [2,3,4]),sol,syms) == [(DiffEqBase.DEFAULT_PLOT_FUNC,1,2), (DiffEqBase.DEFAULT_PLOT_FUNC,1,3), (DiffEqBase.DEFAULT_PLOT_FUNC,1,4)]
2627

27-
@test DiffEqBase.interpret_vars([(:t,:x),(:t,:y)],sol) == [(DiffEqBase.DEFAULT_PLOT_FUNC,0,1), (DiffEqBase.DEFAULT_PLOT_FUNC,0,2)]
28-
@test DiffEqBase.interpret_vars([:x, (0,:x), (:x,:y)],sol) == [(DiffEqBase.DEFAULT_PLOT_FUNC,0,1), (DiffEqBase.DEFAULT_PLOT_FUNC,0,1), (DiffEqBase.DEFAULT_PLOT_FUNC,1,2)]
29-
@test DiffEqBase.interpret_vars([:x, :y, :z],sol) == [(DiffEqBase.DEFAULT_PLOT_FUNC,0,1), (DiffEqBase.DEFAULT_PLOT_FUNC,0,2), (DiffEqBase.DEFAULT_PLOT_FUNC,0,3)]
30-
@test DiffEqBase.interpret_vars(([:x,:x], [:y,:z]),sol) == [(DiffEqBase.DEFAULT_PLOT_FUNC,1,2), (DiffEqBase.DEFAULT_PLOT_FUNC,1,3)]
31-
@test DiffEqBase.interpret_vars((:x, [:y,:z]),sol) == [(DiffEqBase.DEFAULT_PLOT_FUNC,1,2), (DiffEqBase.DEFAULT_PLOT_FUNC,1,3)]
28+
@test DiffEqBase.interpret_vars([(:t,:x),(:t,:y)],sol,syms) == [(DiffEqBase.DEFAULT_PLOT_FUNC,0,1), (DiffEqBase.DEFAULT_PLOT_FUNC,0,2)]
29+
@test DiffEqBase.interpret_vars([:x, (0,:x), (:x,:y)],sol,syms) == [(DiffEqBase.DEFAULT_PLOT_FUNC,0,1), (DiffEqBase.DEFAULT_PLOT_FUNC,0,1), (DiffEqBase.DEFAULT_PLOT_FUNC,1,2)]
30+
@test DiffEqBase.interpret_vars([:x, :y, :z],sol,syms) == [(DiffEqBase.DEFAULT_PLOT_FUNC,0,1), (DiffEqBase.DEFAULT_PLOT_FUNC,0,2), (DiffEqBase.DEFAULT_PLOT_FUNC,0,3)]
31+
@test DiffEqBase.interpret_vars(([:x,:x], [:y,:z]),sol,syms) == [(DiffEqBase.DEFAULT_PLOT_FUNC,1,2), (DiffEqBase.DEFAULT_PLOT_FUNC,1,3)]
32+
@test DiffEqBase.interpret_vars((:x, [:y,:z]),sol,syms) == [(DiffEqBase.DEFAULT_PLOT_FUNC,1,2), (DiffEqBase.DEFAULT_PLOT_FUNC,1,3)]
3233

3334
f(x,y) = (x+y,y)
34-
@test DiffEqBase.interpret_vars([(f,0,1), (1,3), (4,5)],sol) == [(f,0,1), (DiffEqBase.DEFAULT_PLOT_FUNC,1,3), (DiffEqBase.DEFAULT_PLOT_FUNC,4,5)]
35-
@test DiffEqBase.interpret_vars([1, (f,1,3), (4,5)],sol) == [(DiffEqBase.DEFAULT_PLOT_FUNC,0,1), (f,1,3), (DiffEqBase.DEFAULT_PLOT_FUNC,4,5)]
36-
@test DiffEqBase.interpret_vars([(f,:t,:x),(:t,:y)],sol) == [(f,0,1), (DiffEqBase.DEFAULT_PLOT_FUNC,0,2)]
37-
@test DiffEqBase.interpret_vars([:x, (f,0,:x), (:x,:y)],sol) == [(DiffEqBase.DEFAULT_PLOT_FUNC,0,1), (f,0,1), (DiffEqBase.DEFAULT_PLOT_FUNC,1,2)]
38-
@test DiffEqBase.interpret_vars([(:x,:y)],sol) == [(DiffEqBase.DEFAULT_PLOT_FUNC,1,2)]
39-
@test DiffEqBase.interpret_vars((f,:x,:y),sol) == [(f,1,2)]
35+
@test DiffEqBase.interpret_vars([(f,0,1), (1,3), (4,5)],sol,syms) == [(f,0,1), (DiffEqBase.DEFAULT_PLOT_FUNC,1,3), (DiffEqBase.DEFAULT_PLOT_FUNC,4,5)]
36+
@test DiffEqBase.interpret_vars([1, (f,1,3), (4,5)],sol,syms) == [(DiffEqBase.DEFAULT_PLOT_FUNC,0,1), (f,1,3), (DiffEqBase.DEFAULT_PLOT_FUNC,4,5)]
37+
@test DiffEqBase.interpret_vars([(f,:t,:x),(:t,:y)],sol,syms) == [(f,0,1), (DiffEqBase.DEFAULT_PLOT_FUNC,0,2)]
38+
@test DiffEqBase.interpret_vars([:x, (f,0,:x), (:x,:y)],sol,syms) == [(DiffEqBase.DEFAULT_PLOT_FUNC,0,1), (f,0,1), (DiffEqBase.DEFAULT_PLOT_FUNC,1,2)]
39+
@test DiffEqBase.interpret_vars([(:x,:y)],sol,syms) == [(DiffEqBase.DEFAULT_PLOT_FUNC,1,2)]
40+
@test DiffEqBase.interpret_vars((f,:x,:y),sol,syms) == [(f,1,2)]

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ if !is_APPVEYOR && GROUP == "Downstream"
4040
@time @safetestset "Ensemble Analysis" begin include("downstream/ensemble_analysis.jl") end
4141
@time @safetestset "Default linsolve with structure" begin include("downstream/default_linsolve_structure.jl") end
4242
@time @safetestset "Callback Merging Tests" begin include("downstream/callback_merging.jl") end
43+
@time @safetestset "LabelledArrays Tests" begin include("downstream/labelledarrays.jl") end
4344
@time @safetestset "ODE Event Tests" begin include("downstream/ode_event_tests.jl") end
4445
@time @safetestset "Event Detection Tests" begin include("downstream/event_detection_tests.jl") end
4546
@time @safetestset "PSOS and Energy Conservation Event Detection" begin include("downstream/psos_and_energy_conservation.jl") end

0 commit comments

Comments
 (0)