Skip to content

Commit a153d86

Browse files
committed
Rearrange code and fix couple of bugs
1 parent 2c86c8c commit a153d86

File tree

1 file changed

+20
-23
lines changed

1 file changed

+20
-23
lines changed

src/vector_of_array.jl

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,21 @@ mutable struct DiffEqArray{T, N, A, B, C, D, E, F} <: AbstractDiffEqArray{T, N,
1212
p::F
1313
end
1414

15+
Base.@pure __parameterless_type(T) = Base.typename(T).wrapper
16+
parameterless_type(x) = parameterless_type(typeof(x))
17+
parameterless_type(x::Type) = __parameterless_type(x)
18+
19+
### Abstract Interface
20+
struct AllObserved
21+
end
22+
issymbollike(x) = x isa Symbol ||
23+
x isa AllObserved ||
24+
Symbol(parameterless_type(typeof(x))) == :Operation ||
25+
Symbol(parameterless_type(typeof(x))) == :Variable ||
26+
Symbol(parameterless_type(typeof(x))) == :Sym ||
27+
Symbol(parameterless_type(typeof(x))) == :Num ||
28+
Symbol(parameterless_type(typeof(x))) == :Term
29+
1530
Base.Array(VA::AbstractVectorOfArray{T,N,A}) where {T,N,A <: AbstractVector{<:AbstractVector}} = reduce(hcat,VA.u)
1631
Base.Array(VA::AbstractVectorOfArray{T,N,A}) where {T,N,A <: AbstractVector{<:Number}} = VA.u
1732
function Base.Array(VA::AbstractVectorOfArray)
@@ -43,22 +58,6 @@ Base.@propagate_inbounds Base.getindex(VA::AbstractVectorOfArray{T, N}, I::Int)
4358
Base.@propagate_inbounds Base.getindex(VA::AbstractVectorOfArray{T, N}, I::Colon) where {T, N} = VA.u[I]
4459
Base.@propagate_inbounds Base.getindex(VA::AbstractVectorOfArray{T, N}, I::AbstractArray{Int}) where {T, N} = VectorOfArray(VA.u[I])
4560
Base.@propagate_inbounds Base.getindex(VA::AbstractDiffEqArray{T, N}, I::AbstractArray{Int}) where {T, N} = DiffEqArray(VA.u[I],VA.t[I])
46-
#
47-
Base.@pure __parameterless_type(T) = Base.typename(T).wrapper
48-
parameterless_type(x) = parameterless_type(typeof(x))
49-
parameterless_type(x::Type) = __parameterless_type(x)
50-
51-
### Abstract Interface
52-
struct AllObserved
53-
end
54-
issymbollike(x) = x isa Symbol ||
55-
x isa AllObserved ||
56-
Symbol(parameterless_type(typeof(x))) == :Operation ||
57-
Symbol(parameterless_type(typeof(x))) == :Variable ||
58-
Symbol(parameterless_type(typeof(x))) == :Sym ||
59-
Symbol(parameterless_type(typeof(x))) == :Num ||
60-
Symbol(parameterless_type(typeof(x))) == :Term
61-
6261
Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray{T, N},sym) where {T, N}
6362
if issymbollike(sym)
6463
i = findfirst(isequal(Symbol(sym)),A.syms)
@@ -74,7 +73,7 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray{T, N},sym
7473
observed(A,sym,:)
7574
end
7675
else
77-
A[i,:]
76+
Base.getindex.(A.u, i)
7877
end
7978
end
8079
Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray{T, N},sym,args...) where {T, N}
@@ -92,23 +91,21 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray{T, N},sym
9291
observed(A,sym,args...)
9392
end
9493
else
95-
A[i,args...]
94+
Base.getindex.(A.u, args...)
9695
end
9796
end
98-
Base.@propagate_inbounds Base.getindex(A::AbstractDiffEqArray{T, N} where {T, N}, i::Int64, ::Colon) = Base.getindex.(A.u, i)
99-
97+
Base.@propagate_inbounds Base.getindex(A::AbstractDiffEqArray{T, N}, I::Int...) where {T, N} = A.u[I[end]][Base.front(I)...]
98+
Base.@propagate_inbounds Base.getindex(A::AbstractDiffEqArray{T, N}, i::Int) where {T, N} = A.u[i]
10099
function observed(A::AbstractDiffEqArray{T, N},sym,i::Int) where {T, N}
101100
A.observed(sym,A.u[i],A.p,A.t[i])
102101
end
103-
104102
function observed(A::AbstractDiffEqArray{T, N},sym,i::AbstractArray{Int}) where {T, N}
105103
A.observed.((sym,),A.u[i],(A.p,),A.t[i])
106104
end
107-
108105
function observed(A::AbstractDiffEqArray{T, N},sym,::Colon) where {T, N}
109106
A.observed.((sym,),A.u,(A.p,),A.t)
110107
end
111-
#
108+
112109
Base.@propagate_inbounds Base.getindex(VA::AbstractVectorOfArray{T, N}, i::Int,::Colon) where {T, N} = [VA.u[j][i] for j in 1:length(VA)]
113110
Base.@propagate_inbounds function Base.getindex(VA::AbstractVectorOfArray{T,N}, ii::CartesianIndex) where {T, N}
114111
ti = Tuple(ii)

0 commit comments

Comments
 (0)