Skip to content

Commit 7d5f1d5

Browse files
Merge pull request #236 from AayushSabharwal/symbolcache
Add new interface for symbolic indexing
2 parents 8f22d84 + 116942a commit 7d5f1d5

File tree

6 files changed

+62
-31
lines changed

6 files changed

+62
-31
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1616
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
1717
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
1818
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
19+
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
1920
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
2021
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2122

@@ -30,6 +31,7 @@ GPUArraysCore = "0.1"
3031
IteratorInterfaceExtensions = "1"
3132
RecipesBase = "0.7, 0.8, 1.0"
3233
StaticArraysCore = "1.1"
34+
SymbolicIndexingInterface = "0.1"
3335
Tables = "1"
3436
ZygoteRules = "0.2"
3537
julia = "1.6"

src/RecursiveArrayTools.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ module RecursiveArrayTools
77
using DocStringExtensions
88
using RecipesBase, StaticArraysCore, Statistics,
99
ArrayInterfaceCore, LinearAlgebra
10+
using SymbolicIndexingInterface
1011

1112
import ChainRulesCore
1213
import ChainRulesCore: NoTangent

src/tabletraits.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@ function Tables.rows(A::AbstractDiffEqArray)
77
N = length(A.u[1])
88
names = [
99
:timestamp,
10-
(A.syms !== nothing ? (A.syms[i] for i in 1:N) :
10+
(A.sc !== nothing && A.sc.syms !== nothing ? (A.sc.syms[i] for i in 1:N) :
1111
(Symbol("value", i) for i in 1:N))...,
1212
]
1313
types = Type[eltype(A.t), (eltype(A.u[1]) for _ in 1:N)...]
1414
else
15-
names = [:timestamp, A.syms !== nothing ? A.syms[1] : :value]
15+
names = [:timestamp, A.sc !== nothing && A.sc.syms !== nothing ? A.sc.syms[1] : :value]
1616
types = Type[eltype(A.t), VT]
1717
end
1818
return AbstractDiffEqArrayRows(names, types, A.t, A.u)

src/vector_of_array.jl

Lines changed: 41 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,10 @@ A[1,:] # all time periods for f(t)
5353
A.t
5454
```
5555
"""
56-
mutable struct DiffEqArray{T, N, A, B, C, D, E, F} <: AbstractDiffEqArray{T, N, A}
56+
mutable struct DiffEqArray{T, N, A, B, C, E, F} <: AbstractDiffEqArray{T, N, A}
5757
u::A # A <: AbstractVector{<: AbstractArray{T, N - 1}}
5858
t::B
59-
syms::C
60-
indepsym::D
59+
sc::C
6160
observed::E
6261
p::F
6362
end
@@ -94,11 +93,23 @@ VectorOfArray(vec::AbstractVector{T}, ::NTuple{N}) where {T, N} = VectorOfArray{
9493
VectorOfArray(vec::AbstractVector) = VectorOfArray(vec, (size(vec[1])..., length(vec)))
9594
VectorOfArray(vec::AbstractVector{VT}) where {T, N, VT<:AbstractArray{T, N}} = VectorOfArray{T, N+1, typeof(vec)}(vec)
9695

97-
DiffEqArray(vec::AbstractVector{T}, ts, ::NTuple{N}, syms=nothing, indepsym=nothing, observed=nothing, p=nothing) where {T, N} = DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(syms), typeof(indepsym), typeof(observed), typeof(p)}(vec, ts, syms, indepsym, observed, p)
96+
function DiffEqArray(vec::AbstractVector{T}, ts, ::NTuple{N}, syms=nothing, indepsym=nothing, observed=nothing, p=nothing) where {T, N}
97+
sc = if isnothing(indepsym) || indepsym isa AbstractArray
98+
SymbolCache{typeof(syms),typeof(indepsym),Nothing}(syms, indepsym, nothing)
99+
else
100+
SymbolCache{typeof(syms),Vector{typeof(indepsym)},Nothing}(syms, [indepsym], nothing)
101+
end
102+
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(sc), typeof(observed), typeof(p)}(vec, ts, sc, observed, p)
103+
end
98104
# Assume that the first element is representative of all other elements
99105
DiffEqArray(vec::AbstractVector,ts::AbstractVector, syms=nothing, indepsym=nothing, observed=nothing, p=nothing) = DiffEqArray(vec, ts, (size(vec[1])..., length(vec)), syms, indepsym, observed, p)
100106
function DiffEqArray(vec::AbstractVector{VT},ts::AbstractVector, syms=nothing, indepsym=nothing, observed=nothing, p=nothing) where {T, N, VT<:AbstractArray{T, N}}
101-
DiffEqArray{T, N+1, typeof(vec), typeof(ts), typeof(syms), typeof(indepsym), typeof(observed), typeof(p)}(vec, ts, syms, indepsym, observed, p)
107+
sc = if isnothing(indepsym) || indepsym isa AbstractArray
108+
SymbolCache{typeof(syms),typeof(indepsym),Nothing}(syms, indepsym, nothing)
109+
else
110+
SymbolCache{typeof(syms),Vector{typeof(indepsym)},Nothing}(syms, [indepsym], nothing)
111+
end
112+
DiffEqArray{T, N+1, typeof(vec), typeof(ts), typeof(sc), typeof(observed), typeof(p)}(vec, ts, sc, observed, p)
102113
end
103114

104115
# Interface for the linear indexing. This is just a view of the underlying nested structure
@@ -138,37 +149,39 @@ Base.@propagate_inbounds Base.getindex(A::AbstractDiffEqArray{T, N}, i::Int,::Co
138149
Base.@propagate_inbounds Base.getindex(A::AbstractDiffEqArray{T, N}, ::Colon,i::Int) where {T, N} = A.u[i]
139150
Base.@propagate_inbounds Base.getindex(A::AbstractDiffEqArray{T, N}, i::Int,II::AbstractArray{Int}) where {T, N} = [A.u[j][i] for j in II]
140151
Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray{T, N},sym) where {T, N}
141-
if issymbollike(sym) && A.syms !== nothing
142-
i = findfirst(isequal(Symbol(sym)),A.syms)
143-
else
144-
i = sym
145-
end
146-
147-
if i === nothing
148-
if issymbollike(sym) && A.indepsym !== nothing && Symbol(sym) == A.indepsym
149-
A.t
152+
if issymbollike(sym) && !isnothing(A.sc)
153+
if is_indep_sym(A.sc, sym)
154+
return A.t
155+
elseif is_state_sym(A.sc, sym)
156+
return getindex.(A.u, state_sym_to_index(A.sc, sym))
157+
elseif is_param_sym(A.sc, sym)
158+
return A.p[param_sym_to_index(A.sc, sym)]
159+
else
160+
return observed(A, sym, :)
161+
end
162+
elseif all(issymbollike, sym) && !isnothing(A.sc)
163+
if all(Base.Fix1(is_param_sym, A.sc), sym)
164+
return getindex.((A,), sym)
150165
else
151-
observed(A,sym,:)
166+
return [getindex.((A,), sym, i) for i in eachindex(A.t)]
152167
end
153168
else
154-
Base.getindex.(A.u, i)
169+
return getindex.(A.u, sym)
155170
end
156171
end
157172
Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray{T, N},sym,args...) where {T, N}
158-
if issymbollike(sym) && A.syms !== nothing
159-
i = findfirst(isequal(Symbol(sym)),A.syms)
160-
else
161-
i = sym
162-
end
163-
164-
if i === nothing
165-
if issymbollike(sym) && A.indepsym !== nothing && Symbol(sym) == A.indepsym
166-
A.t[args...]
173+
if issymbollike(sym) && !isnothing(A.sc)
174+
if is_indep_sym(A.sc, sym)
175+
return A.t[args...]
176+
elseif is_state_sym(A.sc, sym)
177+
return A[sym][args...]
167178
else
168-
observed(A,sym,args...)
179+
return observed(A, sym, args...)
169180
end
181+
elseif all(issymbollike, sym) && !isnothing(A.sc)
182+
return reduce(vcat, map(s -> A[s, args...]', sym))
170183
else
171-
Base.getindex.(A.u, i, args...)
184+
return getindex.(A.u, sym)
172185
end
173186
end
174187
Base.@propagate_inbounds Base.getindex(A::AbstractDiffEqArray{T, N}, I::Int...) where {T, N} = A.u[I[end]][Base.front(I)...]
@@ -230,8 +243,7 @@ tuples(VA::DiffEqArray) = tuple.(VA.t,VA.u)
230243
Base.copy(VA::AbstractDiffEqArray) = typeof(VA)(
231244
copy(VA.u),
232245
copy(VA.t),
233-
(VA.syms===nothing) ? nothing : copy(VA.syms),
234-
(VA.indepsym===nothing) ? nothing : copy(VA.indepsym),
246+
(VA.sc===nothing) ? nothing : copy(VA.sc),
235247
(VA.observed===nothing) ? nothing : copy(VA.observed),
236248
(VA.p===nothing) ? nothing : copy(VA.p)
237249
)

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ if GROUP == "Core" || GROUP == "All"
2323
@time @testset "Utils Tests" begin include("utils_test.jl") end
2424
@time @testset "Partitions Tests" begin include("partitions_test.jl") end
2525
@time @testset "VecOfArr Indexing Tests" begin include("basic_indexing.jl") end
26+
@time @testset "SymbolicIndexingInterface API test" begin include("symbolic_indexing_interface_test.jl") end
2627
@time @testset "VecOfArr Interface Tests" begin include("interface_tests.jl") end
2728
@time @testset "Table traits" begin include("tabletraits.jl") end
2829
@time @testset "StaticArrays Tests" begin include("copy_static_array_test.jl") end
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
using RecursiveArrayTools, Test
2+
3+
t = 0.0:0.1:1.0
4+
f(x) = 2x
5+
f2(x) = 3x
6+
7+
dx = DiffEqArray([[f(x), f2(x)] for x in t], t, [:a, :b], :t)
8+
@test dx[:t] == t
9+
@test dx[:a] == [f(x) for x in t]
10+
@test dx[:b] == [f2(x) for x in t]
11+
12+
dx = DiffEqArray([[f(x), f2(x)] for x in t], t, [:a, :b], [:t])
13+
@test dx[:t] == t
14+
dx = DiffEqArray([[f(x), f2(x)] for x in t], t, [:a, :b])
15+
@test_throws Exception dx[nothing] # make sure it isn't storing [nothing] as indepsym

0 commit comments

Comments
 (0)