@@ -53,11 +53,10 @@ A[1,:] # all time periods for f(t)
5353A.t
5454```
5555"""
56- mutable struct DiffEqArray{T, N, A, B, C, D, E, F} <: AbstractDiffEqArray{T, N, A}
56+ mutable struct DiffEqArray{T, N, A, B, C, E, F} <: AbstractDiffEqArray{T, N, A}
5757 u:: A # A <: AbstractVector{<: AbstractArray{T, N - 1}}
5858 t:: B
59- syms:: C
60- indepsym:: D
59+ sc:: C
6160 observed:: E
6261 p:: F
6362end
@@ -94,11 +93,23 @@ VectorOfArray(vec::AbstractVector{T}, ::NTuple{N}) where {T, N} = VectorOfArray{
9493VectorOfArray (vec:: AbstractVector ) = VectorOfArray (vec, (size (vec[1 ])... , length (vec)))
9594VectorOfArray (vec:: AbstractVector{VT} ) where {T, N, VT<: AbstractArray{T, N} } = VectorOfArray {T, N+1, typeof(vec)} (vec)
9695
97- DiffEqArray (vec:: AbstractVector{T} , ts, :: NTuple{N} , syms= nothing , indepsym= nothing , observed= nothing , p= nothing ) where {T, N} = DiffEqArray {eltype(T), N, typeof(vec), typeof(ts), typeof(syms), typeof(indepsym), typeof(observed), typeof(p)} (vec, ts, syms, indepsym, observed, p)
96+ function DiffEqArray (vec:: AbstractVector{T} , ts, :: NTuple{N} , syms= nothing , indepsym= nothing , observed= nothing , p= nothing ) where {T, N}
97+ sc = if isnothing (indepsym) || indepsym isa AbstractArray
98+ SymbolCache {typeof(syms),typeof(indepsym),Nothing} (syms, indepsym, nothing )
99+ else
100+ SymbolCache {typeof(syms),Vector{typeof(indepsym)},Nothing} (syms, [indepsym], nothing )
101+ end
102+ DiffEqArray {eltype(T), N, typeof(vec), typeof(ts), typeof(sc), typeof(observed), typeof(p)} (vec, ts, sc, observed, p)
103+ end
98104# Assume that the first element is representative of all other elements
99105DiffEqArray (vec:: AbstractVector ,ts:: AbstractVector , syms= nothing , indepsym= nothing , observed= nothing , p= nothing ) = DiffEqArray (vec, ts, (size (vec[1 ])... , length (vec)), syms, indepsym, observed, p)
100106function DiffEqArray (vec:: AbstractVector{VT} ,ts:: AbstractVector , syms= nothing , indepsym= nothing , observed= nothing , p= nothing ) where {T, N, VT<: AbstractArray{T, N} }
101- DiffEqArray {T, N+1, typeof(vec), typeof(ts), typeof(syms), typeof(indepsym), typeof(observed), typeof(p)} (vec, ts, syms, indepsym, observed, p)
107+ sc = if isnothing (indepsym) || indepsym isa AbstractArray
108+ SymbolCache {typeof(syms),typeof(indepsym),Nothing} (syms, indepsym, nothing )
109+ else
110+ SymbolCache {typeof(syms),Vector{typeof(indepsym)},Nothing} (syms, [indepsym], nothing )
111+ end
112+ DiffEqArray {T, N+1, typeof(vec), typeof(ts), typeof(sc), typeof(observed), typeof(p)} (vec, ts, sc, observed, p)
102113end
103114
104115# Interface for the linear indexing. This is just a view of the underlying nested structure
@@ -138,37 +149,39 @@ Base.@propagate_inbounds Base.getindex(A::AbstractDiffEqArray{T, N}, i::Int,::Co
138149Base. @propagate_inbounds Base. getindex (A:: AbstractDiffEqArray{T, N} , :: Colon ,i:: Int ) where {T, N} = A. u[i]
139150Base. @propagate_inbounds Base. getindex (A:: AbstractDiffEqArray{T, N} , i:: Int ,II:: AbstractArray{Int} ) where {T, N} = [A. u[j][i] for j in II]
140151Base. @propagate_inbounds function Base. getindex (A:: AbstractDiffEqArray{T, N} ,sym) where {T, N}
141- if issymbollike (sym) && A. syms != = nothing
142- i = findfirst (isequal (Symbol (sym)),A. syms)
143- else
144- i = sym
145- end
146-
147- if i === nothing
148- if issymbollike (sym) && A. indepsym != = nothing && Symbol (sym) == A. indepsym
149- A. t
152+ if issymbollike (sym) && ! isnothing (A. sc)
153+ if is_indep_sym (A. sc, sym)
154+ return A. t
155+ elseif is_state_sym (A. sc, sym)
156+ return getindex .(A. u, state_sym_to_index (A. sc, sym))
157+ elseif is_param_sym (A. sc, sym)
158+ return A. p[param_sym_to_index (A. sc, sym)]
159+ else
160+ return observed (A, sym, :)
161+ end
162+ elseif all (issymbollike, sym) && ! isnothing (A. sc)
163+ if all (Base. Fix1 (is_param_sym, A. sc), sym)
164+ return getindex .((A,), sym)
150165 else
151- observed (A, sym,:)
166+ return [ getindex .((A,), sym, i) for i in eachindex (A . t)]
152167 end
153168 else
154- Base . getindex .(A. u, i )
169+ return getindex .(A. u, sym )
155170 end
156171end
157172Base. @propagate_inbounds function Base. getindex (A:: AbstractDiffEqArray{T, N} ,sym,args... ) where {T, N}
158- if issymbollike (sym) && A. syms != = nothing
159- i = findfirst (isequal (Symbol (sym)),A. syms)
160- else
161- i = sym
162- end
163-
164- if i === nothing
165- if issymbollike (sym) && A. indepsym != = nothing && Symbol (sym) == A. indepsym
166- A. t[args... ]
173+ if issymbollike (sym) && ! isnothing (A. sc)
174+ if is_indep_sym (A. sc, sym)
175+ return A. t[args... ]
176+ elseif is_state_sym (A. sc, sym)
177+ return A[sym][args... ]
167178 else
168- observed (A,sym,args... )
179+ return observed (A, sym, args... )
169180 end
181+ elseif all (issymbollike, sym) && ! isnothing (A. sc)
182+ return reduce (vcat, map (s -> A[s, args... ]' , sym))
170183 else
171- Base . getindex .(A. u, i, args ... )
184+ return getindex .(A. u, sym )
172185 end
173186end
174187Base. @propagate_inbounds Base. getindex (A:: AbstractDiffEqArray{T, N} , I:: Int... ) where {T, N} = A. u[I[end ]][Base. front (I)... ]
@@ -230,8 +243,7 @@ tuples(VA::DiffEqArray) = tuple.(VA.t,VA.u)
230243Base. copy (VA:: AbstractDiffEqArray ) = typeof (VA)(
231244 copy (VA. u),
232245 copy (VA. t),
233- (VA. syms=== nothing ) ? nothing : copy (VA. syms),
234- (VA. indepsym=== nothing ) ? nothing : copy (VA. indepsym),
246+ (VA. sc=== nothing ) ? nothing : copy (VA. sc),
235247 (VA. observed=== nothing ) ? nothing : copy (VA. observed),
236248 (VA. p=== nothing ) ? nothing : copy (VA. p)
237249 )
0 commit comments