Skip to content

Commit ae039e3

Browse files
Add symbolic_indexing_interface
- SymbolCache and interface functions added and exported - `states`, `parameters` and `independent_variables` hoisted down from MTK
1 parent 8005b83 commit ae039e3

File tree

3 files changed

+88
-21
lines changed

3 files changed

+88
-21
lines changed

src/RecursiveArrayTools.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ abstract type AbstractVectorOfArray{T, N, A} <: AbstractArray{T, N} end
2525
abstract type AbstractDiffEqArray{T, N, A} <: AbstractVectorOfArray{T, N, A} end
2626

2727
include("utils.jl")
28+
include("symbolic_indexing_interface.jl")
2829
include("vector_of_array.jl")
2930
include("tabletraits.jl")
3031
include("array_partition.jl")
@@ -36,6 +37,9 @@ import GPUArraysCore
3637
Base.convert(T::Type{<:GPUArraysCore.AbstractGPUArray}, VA::AbstractVectorOfArray) = T(VA)
3738
ChainRulesCore.rrule(T::Type{<:GPUArraysCore.AbstractGPUArray}, xs::AbstractVectorOfArray) = T(xs), ȳ -> (NoTangent(),ȳ)
3839

40+
export independent_variables, is_indep_sym, states, state_sym_to_index, is_state_sym,
41+
parameters, param_sym_to_index, is_param_sym, SymbolCache
42+
3943
export VectorOfArray, DiffEqArray, AbstractVectorOfArray, AbstractDiffEqArray,
4044
AllObserved, vecarr_to_arr, vecarr_to_vectors, tuples
4145

src/symbolic_indexing_interface.jl

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
"""
2+
$(TYPEDSIGNATURES)
3+
4+
Get the set of independent variables for the given system.
5+
"""
6+
function independent_variables end
7+
8+
"""
9+
$(TYPEDSIGNATURES)
10+
11+
Check if the given sym is an independent variable in the given system.
12+
"""
13+
function is_indep_sym end
14+
15+
"""
16+
$(TYPEDSIGNATURES)
17+
18+
Get the set of states for the given system.
19+
"""
20+
function states end
21+
22+
"""
23+
$(TYPEDSIGNATURES)
24+
25+
Find the index of the given sym in the given system.
26+
"""
27+
function state_sym_to_index end
28+
29+
"""
30+
$(TYPEDSIGNATURES)
31+
32+
Check if the given sym is a state variable in the given system.
33+
"""
34+
function is_state_sym end
35+
36+
"""
37+
$(TYPEDSIGNATURES)
38+
39+
Get the set of parameters variables for the given system.
40+
"""
41+
function parameters end
42+
43+
"""
44+
$(TYPEDSIGNATURES)
45+
46+
Find the index of the given sym in the given system.
47+
"""
48+
function param_sym_to_index end
49+
50+
"""
51+
$(TYPEDSIGNATURES)
52+
53+
Check if the given sym is a parameter variable in the given system.
54+
"""
55+
function is_param_sym end
56+
57+
struct SymbolCache{S,T,U}
58+
syms::S
59+
indepsym::T
60+
paramsyms::U
61+
end
62+
63+
64+
independent_variables(sc::SymbolCache) = sc.indepsym
65+
independent_variables(::SymbolCache{S,Nothing}) where {S} = []
66+
is_indep_sym(sc::SymbolCache, sym) = any(isequal(sym), sc.indepsym)
67+
is_indep_sym(::SymbolCache{S,Nothing}, _) where {S} = false
68+
states(sc::SymbolCache) = sc.syms
69+
states(::SymbolCache{Nothing}) = []
70+
state_sym_to_index(sc::SymbolCache, sym) = findfirst(isequal(sym), sc.syms)
71+
state_sym_to_index(::SymbolCache{Nothing}, _) = nothing
72+
is_state_sym(sc::SymbolCache, sym) = !isnothing(state_sym_to_index(sc, sym))
73+
parameters(sc::SymbolCache) = sc.paramsyms
74+
parameters(::SymbolCache{S,T,Nothing}) where {S,T} = []
75+
param_sym_to_index(sc::SymbolCache, sym) = findfirst(isequal(sym), sc.paramsyms)
76+
param_sym_to_index(::SymbolCache{S,T,Nothing}, _) where {S,T} = nothing
77+
is_param_sym(sc::SymbolCache, sym) = !isnothing(param_sym_to_index(sc, sym))
78+
79+
Base.copy(VA::SymbolCache) = typeof(VA)(
80+
(VA.syms===nothing) ? nothing : copy(VA.syms),
81+
(VA.indepsym===nothing) ? nothing : copy(VA.indepsym),
82+
(VA.paramsyms===nothing) ? nothing : copy(VA.paramsyms),
83+
)
84+

src/vector_of_array.jl

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -33,27 +33,6 @@ mutable struct VectorOfArray{T, N, A} <: AbstractVectorOfArray{T, N, A}
3333
end
3434
# VectorOfArray with an added series for time
3535

36-
37-
struct SymbolCache{S,T,U}
38-
syms::S
39-
indepsym::T
40-
paramsyms::U
41-
end
42-
43-
is_indep_sym(sc::SymbolCache, sym) = isequal(sc.indepsym, sym)
44-
is_indep_sym(::SymbolCache{S,Nothing}, _) where {S} = false
45-
state_sym_to_index(sc::SymbolCache, sym) = findfirst(isequal(sym), sc.syms)
46-
state_sym_to_index(::SymbolCache{Nothing}, _) = nothing
47-
is_state_sym(sc::SymbolCache, sym) = !isnothing(state_sym_to_index(sc, sym))
48-
param_sym_to_index(sc::SymbolCache, sym) = findfirst(isequal(sym), sc.paramsyms)
49-
param_sym_to_index(::SymbolCache{S,T,Nothing}, _) where {S,T} = nothing
50-
is_param_sym(sc::SymbolCache, sym) = !isnothing(param_sym_to_index(sc, sym))
51-
52-
Base.copy(VA::SymbolCache) = typeof(VA)(
53-
(VA.syms===nothing) ? nothing : copy(VA.syms),
54-
(VA.indepsym===nothing) ? nothing : copy(VA.indepsym),
55-
(VA.paramsyms===nothing) ? nothing : copy(VA.paramsyms),
56-
)
5736
"""
5837
```julia
5938
DiffEqArray(u::AbstractVector,t::AbstractVector)

0 commit comments

Comments
 (0)