Skip to content

Commit 1607409

Browse files
Merge pull request #222 from SciML/dw/tabletraits
Implement table traits for DiffEq arrays
2 parents 6f7eedd + b26e35e commit 1607409

File tree

7 files changed

+183
-3
lines changed

7 files changed

+183
-3
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "RecursiveArrayTools"
22
uuid = "731186ca-8d62-57ce-b412-fbd966d074cd"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "2.31.3"
4+
version = "2.32.0"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -11,10 +11,12 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1111
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1212
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1313
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
14+
IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d"
1415
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1516
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
1617
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
1718
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
19+
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
1820
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1921

2022
[compat]
@@ -25,8 +27,10 @@ ChainRulesCore = "0.10.7, 1"
2527
DocStringExtensions = "0.8, 0.9"
2628
FillArrays = "0.11, 0.12, 0.13"
2729
GPUArraysCore = "0.1"
30+
IteratorInterfaceExtensions = "1"
2831
RecipesBase = "0.7, 0.8, 1.0"
2932
StaticArraysCore = "1"
33+
Tables = "1"
3034
ZygoteRules = "0.2"
3135
julia = "1.6"
3236

src/RecursiveArrayTools.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,14 @@ import ArrayInterfaceStaticArraysCore
1919

2020
using FillArrays
2121

22+
import Tables, IteratorInterfaceExtensions
23+
2224
abstract type AbstractVectorOfArray{T, N, A} <: AbstractArray{T, N} end
2325
abstract type AbstractDiffEqArray{T, N, A} <: AbstractVectorOfArray{T, N, A} end
2426

2527
include("utils.jl")
2628
include("vector_of_array.jl")
29+
include("tabletraits.jl")
2730
include("array_partition.jl")
2831
include("zygote.jl")
2932

src/tabletraits.jl

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Tables traits for AbstractDiffEqArray
2+
Tables.istable(::Type{<:AbstractDiffEqArray}) = true
3+
Tables.rowaccess(::Type{<:AbstractDiffEqArray}) = true
4+
function Tables.rows(A::AbstractDiffEqArray)
5+
VT = eltype(A.u)
6+
if VT <: AbstractArray
7+
N = length(A.u[1])
8+
names = [
9+
:timestamp,
10+
(A.syms !== nothing ? (A.syms[i] for i in 1:N) :
11+
(Symbol("value", i) for i in 1:N))...,
12+
]
13+
types = Type[eltype(A.t), (eltype(A.u[1]) for _ in 1:N)...]
14+
else
15+
names = [:timestamp, A.syms !== nothing ? A.syms[1] : :value]
16+
types = Type[eltype(A.t), VT]
17+
end
18+
return AbstractDiffEqArrayRows(names, types, A.t, A.u)
19+
end
20+
21+
# Override fallback definitions for AbstractMatrix
22+
Tables.istable(::AbstractDiffEqArray) = true # Ref: https://github.com/JuliaData/Tables.jl/pull/198
23+
Tables.columns(x::AbstractDiffEqArray) = Tables.columntable(Tables.rows(x))
24+
25+
# Iterator of Tables.AbstractRow rows
26+
struct AbstractDiffEqArrayRows{T, U}
27+
names::Vector{Symbol}
28+
types::Vector{Type}
29+
lookup::Dict{Symbol, Int}
30+
t::T
31+
u::U
32+
end
33+
function AbstractDiffEqArrayRows(names, types, t, u)
34+
AbstractDiffEqArrayRows(names, types,
35+
Dict(nm => i for (i, nm) in enumerate(names)), t, u)
36+
end
37+
38+
Base.length(x::AbstractDiffEqArrayRows) = length(x.u)
39+
function Base.eltype(::Type{AbstractDiffEqArrayRows{T, U}}) where {T, U}
40+
AbstractDiffEqArrayRow{eltype(T), eltype(U)}
41+
end
42+
function Base.iterate(x::AbstractDiffEqArrayRows, (t_state, u_state)=(iterate(x.t), iterate(x.u)))
43+
t_state === nothing && return nothing
44+
u_state === nothing && return nothing
45+
t, _t_state = t_state
46+
u, _u_state = u_state
47+
st = (iterate(x.t, _t_state), iterate(x.u, _u_state))
48+
return (AbstractDiffEqArrayRow(x.names, x.lookup, t, u), st)
49+
end
50+
51+
Tables.istable(::Type{<:AbstractDiffEqArrayRows}) = true
52+
Tables.rowaccess(::Type{<:AbstractDiffEqArrayRows}) = true
53+
Tables.rows(x::AbstractDiffEqArrayRows) = x
54+
Tables.schema(x::AbstractDiffEqArrayRows) = Tables.Schema(x.names, x.types)
55+
56+
# AbstractRow subtype
57+
struct AbstractDiffEqArrayRow{T, U} <: Tables.AbstractRow
58+
names::Vector{Symbol}
59+
lookup::Dict{Symbol, Int}
60+
t::T
61+
u::U
62+
end
63+
64+
Tables.columnnames(x::AbstractDiffEqArrayRow) = getfield(x, :names)
65+
function Tables.getcolumn(x::AbstractDiffEqArrayRow, i::Int)
66+
i == 1 ? getfield(x, :t) : getfield(x, :u)[i - 1]
67+
end
68+
function Tables.getcolumn(x::AbstractDiffEqArrayRow, nm::Symbol)
69+
nm === :timestamp ? getfield(x, :t) : getfield(x, :u)[getfield(x, :lookup)[nm] - 1]
70+
end
71+
72+
# Iterator interface for QueryVerse
73+
# (see also https://tables.juliadata.org/stable/#Tables.datavaluerows)
74+
IteratorInterfaceExtensions.isiterable(::AbstractDiffEqArray) = true
75+
function IteratorInterfaceExtensions.getiterator(A::AbstractDiffEqArray)
76+
Tables.datavaluerows(Tables.rows(A))
77+
end

test/downstream/symbol_indexing.jl

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
using RecursiveArrayTools, ModelingToolkit, OrdinaryDiffEq, Test
22

3+
include("../testutils.jl")
4+
35
@variables t x(t)
46
@parameters τ
57
D = Differential(t)
68
@variables RHS(t)
79
@named fol_separate = ODESystem([ RHS ~ (1 - x)/τ,
810
D(x) ~ RHS ])
9-
fol_simplified = structural_simplify(fol_separate)
11+
fol_simplified = structural_simplify(fol_separate)
1012

1113
prob = ODEProblem(fol_simplified, [x => 0.0], (0.0,10.0), [τ => 3.0])
1214
sol = solve(prob, Tsit5())
@@ -22,4 +24,22 @@ sol_new = DiffEqArray(
2224

2325
@test sol_new[RHS] (1 .- sol_new[x])./3.0
2426
@test sol_new[t] sol_new.t
25-
@test sol_new[t, 1:5] sol_new.t[1:5]
27+
@test sol_new[t, 1:5] sol_new.t[1:5]
28+
29+
# Tables interface
30+
test_tables_interface(sol_new, [:timestamp, Symbol("x(t)")], hcat(sol_new[t], sol_new[x]))
31+
32+
# Two components
33+
@variables y(t)
34+
@parameters α β γ δ
35+
@named lv = ODESystem([ D(x) ~ α * x - β * x * y,
36+
D(y) ~ δ * x * y - γ * x * y])
37+
38+
prob = ODEProblem(lv, [x => 1.0, y => 1.0], (0.0, 10.0),
39+
=> 1.5, β => 1.0, γ => 3.0, δ => 1.0])
40+
sol = solve(prob, Tsit5())
41+
42+
ts = 0:0.5:10
43+
sol_ts = sol(ts)
44+
@assert sol_ts isa DiffEqArray
45+
test_tables_interface(sol_ts, [:timestamp, Symbol("x(t)"), Symbol("y(t)")], hcat(ts, Array(sol_ts)'))

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ if GROUP == "Core" || GROUP == "All"
2424
@time @testset "Partitions Tests" begin include("partitions_test.jl") end
2525
@time @testset "VecOfArr Indexing Tests" begin include("basic_indexing.jl") end
2626
@time @testset "VecOfArr Interface Tests" begin include("interface_tests.jl") end
27+
@time @testset "Table traits" begin include("tabletraits.jl") end
2728
@time @testset "StaticArrays Tests" begin include("copy_static_array_test.jl") end
2829
@time @testset "Linear Algebra Tests" begin include("linalg.jl") end
2930
@time @testset "Upstream Tests" begin include("upstream.jl") end

test/tabletraits.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
using RecursiveArrayTools, Random, Test
2+
3+
include("testutils.jl")
4+
5+
Random.seed!(1234)
6+
7+
n = 20
8+
t = sort(randn(n))
9+
u = randn(n)
10+
A = DiffEqArray(u, t)
11+
test_tables_interface(A, [:timestamp, :value], hcat(t, u))
12+
13+
u = [randn(3) for _ in 1:n]
14+
A = DiffEqArray(u, t)
15+
test_tables_interface(A, [:timestamp, :value1, :value2, :value3], hcat(t, reduce(vcat, u')))

test/testutils.jl

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
using RecursiveArrayTools
2+
using RecursiveArrayTools: Tables, IteratorInterfaceExtensions
3+
4+
# Test Tables interface with row access + IteratorInterfaceExtensions for QueryVerse
5+
# (see https://tables.juliadata.org/stable/#Testing-Tables.jl-Implementations)
6+
function test_tables_interface(x::AbstractDiffEqArray, names::Vector{Symbol}, values::Matrix)
7+
@assert length(names) == size(values, 2)
8+
9+
# AbstractDiffEqArray is a table with row access
10+
@test Tables.istable(x)
11+
@test Tables.istable(typeof(x))
12+
@test Tables.rowaccess(x)
13+
@test Tables.rowaccess(typeof(x))
14+
@test !Tables.columnaccess(x)
15+
@test !Tables.columnaccess(typeof(x))
16+
17+
# Check implementation of AbstractRow iterator
18+
tbl = Tables.rows(x)
19+
@test length(tbl) == size(values, 1)
20+
@test Tables.istable(tbl)
21+
@test Tables.istable(typeof(tbl))
22+
@test Tables.rowaccess(tbl)
23+
@test Tables.rowaccess(typeof(tbl))
24+
@test Tables.rows(tbl) === tbl
25+
26+
# Check implementation of AbstractRow subtype
27+
for (i, row) in enumerate(tbl)
28+
@test eltype(tbl) === typeof(row)
29+
@test propertynames(row) == Tables.columnnames(row) == names
30+
for (j, name) in enumerate(names)
31+
@test getproperty(row, name) == Tables.getcolumn(row, name) == Tables.getcolumn(row, j) == values[i, j]
32+
end
33+
end
34+
35+
# Check column access
36+
coltbl = Tables.columns(x)
37+
@test length(coltbl) == size(values, 2)
38+
@test Tables.istable(coltbl)
39+
@test Tables.istable(typeof(coltbl))
40+
@test Tables.columnaccess(coltbl)
41+
@test Tables.columnaccess(typeof(coltbl))
42+
@test Tables.columns(coltbl) === coltbl
43+
@test propertynames(coltbl) == Tables.columnnames(coltbl) == Tuple(names)
44+
for (i, name) in enumerate(names)
45+
@test getproperty(coltbl, name) == Tables.getcolumn(coltbl, name) == Tables.getcolumn(coltbl, i) == values[:, i]
46+
end
47+
48+
# IteratorInterfaceExtensions
49+
@test IteratorInterfaceExtensions.isiterable(x)
50+
iterator = IteratorInterfaceExtensions.getiterator(x)
51+
for (i, row) in enumerate(iterator)
52+
@test row isa NamedTuple
53+
@test propertynames(row) == Tuple(names)
54+
for (j, name) in enumerate(names)
55+
@test getproperty(row, name) == row[j] == values[i, j]
56+
end
57+
end
58+
59+
nothing
60+
end

0 commit comments

Comments
 (0)