Skip to content

Commit be038d5

Browse files
Merge pull request #148 from sharanry/sy/add_DiffEqArray_special_case
Add special cases for DiffEqArray constructor
2 parents 81e1658 + 519e0be commit be038d5

File tree

3 files changed

+120
-13
lines changed

3 files changed

+120
-13
lines changed

src/vector_of_array.jl

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,12 @@ VectorOfArray(vec::AbstractVector{T}, ::NTuple{N}) where {T, N} = VectorOfArray{
3939
VectorOfArray(vec::AbstractVector) = VectorOfArray(vec, (size(vec[1])..., length(vec)))
4040
VectorOfArray(vec::AbstractVector{VT}) where {T, N, VT<:AbstractArray{T, N}} = VectorOfArray{T, N+1, typeof(vec)}(vec)
4141

42-
DiffEqArray(vec::AbstractVector{T}, ts, ::NTuple{N}) where {T, N} = DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), Nothing, Nothing, Nothing, Nothing}(vec, ts, nothing, nothing, nothing, nothing)
42+
DiffEqArray(vec::AbstractVector{T}, ts, ::NTuple{N}, syms=nothing, indepsym=nothing, observed=nothing, p=nothing) where {T, N} = DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(syms), typeof(indepsym), typeof(observed), typeof(p)}(vec, ts, syms, indepsym, observed, p)
4343
# Assume that the first element is representative of all other elements
44-
DiffEqArray(vec::AbstractVector,ts::AbstractVector) = DiffEqArray(vec, ts, (size(vec[1])..., length(vec)))
45-
DiffEqArray(vec::AbstractVector{VT},ts::AbstractVector) where {T, N, VT<:AbstractArray{T, N}} = DiffEqArray{T, N+1, typeof(vec), typeof(ts), Nothing, Nothing, Nothing, Nothing}(vec, ts, nothing, nothing, nothing, nothing)
46-
DiffEqArray(vec::AbstractVector{VT},ts::AbstractVector, syms::Vector{Symbol}, indepsym::Symbol, observed::Function, p) where {T, N, VT<:AbstractArray{T, N}} = DiffEqArray{T, N+1, typeof(vec), typeof(ts), typeof(syms), typeof(indepsym), typeof(observed), typeof(p)}(vec, ts, syms, indepsym, observed, p)
44+
DiffEqArray(vec::AbstractVector,ts::AbstractVector, syms=nothing, indepsym=nothing, observed=nothing, p=nothing) = DiffEqArray(vec, ts, (size(vec[1])..., length(vec)), syms, indepsym, observed, p)
45+
function DiffEqArray(vec::AbstractVector{VT},ts::AbstractVector, syms=nothing, indepsym=nothing, observed=nothing, p=nothing) where {T, N, VT<:AbstractArray{T, N}}
46+
DiffEqArray{T, N+1, typeof(vec), typeof(ts), typeof(syms), typeof(indepsym), typeof(observed), typeof(p)}(vec, ts, syms, indepsym, observed, p)
47+
end
4748

4849
# Interface for the linear indexing. This is just a view of the underlying nested structure
4950
@inline Base.firstindex(VA::AbstractVectorOfArray) = firstindex(VA.u)
@@ -102,6 +103,13 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractDiffEqArray{T, N},sym
102103
end
103104
Base.@propagate_inbounds Base.getindex(A::AbstractDiffEqArray{T, N}, I::Int...) where {T, N} = A.u[I[end]][Base.front(I)...]
104105
Base.@propagate_inbounds Base.getindex(A::AbstractDiffEqArray{T, N}, i::Int) where {T, N} = A.u[i]
106+
Base.@propagate_inbounds function Base.getindex(VA::AbstractDiffEqArray{T,N}, ii::CartesianIndex) where {T, N}
107+
ti = Tuple(ii)
108+
i = last(ti)
109+
jj = CartesianIndex(Base.front(ti))
110+
return VA.u[i][jj]
111+
end
112+
105113
function observed(A::AbstractDiffEqArray{T, N},sym,i::Int) where {T, N}
106114
A.observed(sym,A.u[i],A.p,A.t[i])
107115
end
@@ -149,6 +157,14 @@ end
149157
tuples(VA::DiffEqArray) = tuple.(VA.t,VA.u)
150158

151159
# Growing the array simply adds to the container vector
160+
Base.copy(VA::AbstractDiffEqArray) = typeof(VA)(
161+
copy(VA.u),
162+
copy(VA.t),
163+
(VA.syms===nothing) ? nothing : copy(VA.syms),
164+
(VA.indepsym===nothing) ? nothing : copy(VA.indepsym),
165+
(VA.observed===nothing) ? nothing : copy(VA.observed),
166+
(VA.p===nothing) ? nothing : copy(VA.p)
167+
)
152168
Base.copy(VA::AbstractVectorOfArray) = typeof(VA)(copy(VA.u))
153169
Base.sizehint!(VA::AbstractVectorOfArray{T, N}, i) where {T, N} = sizehint!(VA.u, i)
154170
Base.push!(VA::AbstractVectorOfArray{T, N}, new_item::AbstractVector) where {T, N} = push!(VA.u, new_item)

test/basic_indexing.jl

Lines changed: 61 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@ fill!(mulX, 0)
1515
mulX .= sqrt.(abs.(testva .* X))
1616
@test mulX == ref
1717

18-
t = [1,2,3]
19-
diffeq = DiffEqArray(recs,t)
2018
@test Array(testva) == [1 4 7
2119
2 5 8
2220
3 6 9]
@@ -25,7 +23,15 @@ diffeq = DiffEqArray(recs,t)
2523
@test testva[1:2, 1:2] == [1 4; 2 5]
2624
@test testa[1:2, 1:2] == [1 4; 2 5]
2725

26+
t = [1,2,3]
27+
diffeq = DiffEqArray(recs,t)
28+
@test Array(diffeq) == [1 4 7
29+
2 5 8
30+
3 6 9]
31+
@test diffeq[1:2, 1:2] == [1 4; 2 5]
32+
2833
# # ndims == 2
34+
t = 1:10
2935
recs = [rand(8) for i in 1:10]
3036
testa = cat(recs...,dims=2)
3137
testva = VectorOfArray(recs)
@@ -36,63 +42,100 @@ testva = VectorOfArray(recs)
3642
@test testva[end] == testa[:, end]
3743
@test testva[2:end] == VectorOfArray([recs[i] for i = 2:length(recs)])
3844

45+
diffeq = DiffEqArray(recs,t)
46+
@test diffeq[1] == testa[:, 1]
47+
@test diffeq[:] == recs
48+
@test diffeq[end] == testa[:, end]
49+
@test diffeq[2:end] == DiffEqArray([recs[i] for i = 2:length(recs)], t)
50+
3951
# ## (Int, Int)
4052
@test testa[5, 4] == testva[5, 4]
53+
@test testa[5, 4] == diffeq[5, 4]
4154

4255
# ## (Int, Range) or (Range, Int)
4356
@test testa[1, 2:3] == testva[1, 2:3]
4457
@test testa[5:end, 1] == testva[5:end, 1]
4558
@test testa[:, 1] == testva[:, 1]
4659
@test testa[3, :] == testva[3, :]
4760

61+
@test testa[1, 2:3] == diffeq[1, 2:3]
62+
@test testa[5:end, 1] == diffeq[5:end, 1]
63+
@test testa[:, 1] == diffeq[:, 1]
64+
@test testa[3, :] == diffeq[3, :]
65+
4866
# ## (Range, Range)
4967
@test testa[5:end, 1:2] == testva[5:end, 1:2]
68+
@test testa[5:end, 1:2] == diffeq[5:end, 1:2]
5069

5170
# # ndims == 3
71+
t = 1:15
5272
recs = recs = [rand(10, 8) for i in 1:15]
5373
testa = cat(recs...,dims=3)
5474
testva = VectorOfArray(recs)
75+
diffeq = DiffEqArray(recs,t)
5576

5677
# ## (Int, Int, Int)
5778
@test testa[1, 7, 14] == testva[1, 7, 14]
79+
@test testa[1, 7, 14] == diffeq[1, 7, 14]
5880

5981
# ## (Int, Int, Range)
6082
@test testa[2, 3, 1:2] == testva[2, 3, 1:2]
83+
@test testa[2, 3, 1:2] == diffeq[2, 3, 1:2]
6184

6285
# ## (Int, Range, Int)
6386
@test testa[2, 3:4, 1] == testva[2, 3:4, 1]
87+
@test testa[2, 3:4, 1] == diffeq[2, 3:4, 1]
6488

6589
# ## (Int, Range, Range)
6690
@test testa[2, 3:4, 1:2] == testva[2, 3:4, 1:2]
91+
@test testa[2, 3:4, 1:2] == diffeq[2, 3:4, 1:2]
6792

6893
# ## (Range, Int, Range)
6994
@test testa[2:3, 1, 1:2] == testva[2:3, 1, 1:2]
95+
@test testa[2:3, 1, 1:2] == diffeq[2:3, 1, 1:2]
7096

7197
# ## (Range, Range, Int)
7298
@test testa[1:2, 2:3, 1] == testva[1:2, 2:3, 1]
99+
@test testa[1:2, 2:3, 1] == diffeq[1:2, 2:3, 1]
73100

74101
# ## (Range, Range, Range)
75102
@test testa[2:3, 2:3, 1:2] == testva[2:3, 2:3, 1:2]
103+
@test testa[2:3, 2:3, 1:2] == diffeq[2:3, 2:3, 1:2]
76104

77105
# ## Make sure that 1:1 like ranges are not collapsed
78106
@test testa[1:1, 2:3, 1:2] == testva[1:1, 2:3, 1:2]
107+
@test testa[1:1, 2:3, 1:2] == diffeq[1:1, 2:3, 1:2]
79108

80109
# ## Test ragged arrays work, or give errors as needed
81110
#TODO: I am not really sure what the behavior of this is, what does Mathematica do?
111+
t = 1:3
82112
recs = [[1, 2, 3], [3, 5, 6, 7], [8, 9, 10, 11]]
83113
testva = VectorOfArray(recs) #TODO: clearly this printed form is nonsense
114+
diffeq = DiffEqArray(recs,t)
115+
84116
@test testva[:, 1] == recs[1]
85-
testva[1:2, 1:2]
117+
@test testva[1:2, 1:2] == [1 3; 2 5]
118+
@test diffeq[:, 1] == recs[1]
119+
@test diffeq[1:2, 1:2] == [1 3; 2 5]
86120

121+
t = 1:5
87122
recs = [rand(2,2) for i in 1:5]
88123
testva = VectorOfArray(recs)
124+
diffeq = DiffEqArray(recs,t)
125+
89126
@test Array(testva) isa Array{Float64,3}
127+
@test Array(diffeq) isa Array{Float64,3}
90128

91129
v = VectorOfArray([zeros(20), zeros(10,10), zeros(3,3,3)])
92130
v[CartesianIndex((2, 3, 2, 3))] = 1
93131
@test v[CartesianIndex((2, 3, 2, 3))] == 1
94132
@test v.u[3][2, 3, 2] == 1
95133

134+
v = DiffEqArray([zeros(20), zeros(10,10), zeros(3,3,3)], 1:3)
135+
v[CartesianIndex((2, 3, 2, 3))] = 1
136+
@test v[CartesianIndex((2, 3, 2, 3))] == 1
137+
@test v.u[3][2, 3, 2] == 1
138+
96139
v = VectorOfArray([rand(20), rand(10,10), rand(3,3,3)])
97140
w = v .* v
98141
@test w isa VectorOfArray
@@ -103,12 +146,25 @@ w = v .* v
103146
x = copy(v)
104147
x .= v .* v
105148
@test x.u == w.u
106-
107-
# broadcast with number
108149
w = v .+ 1
109150
@test w isa VectorOfArray
110151
@test w.u == map(x -> x .+ 1, v.u)
111152

153+
154+
v = DiffEqArray([rand(20), rand(10,10), rand(3,3,3)], 1:3)
155+
w = v .* v
156+
@test_broken w isa DiffEqArray # FIXME
157+
@test w[1] isa Vector
158+
@test w[1] == v[1] .* v[1]
159+
@test w[2] == v[2] .* v[2]
160+
@test w[3] == v[3] .* v[3]
161+
x = copy(v)
162+
x .= v .* v
163+
@test x.u == w.u
164+
w = v .+ 1
165+
@test_broken w isa DiffEqArray # FIXME
166+
@test w.u == map(x -> x .+ 1, v.u)
167+
112168
# edges cases
113169
x = [1, 2, 3, 4, 5, 6, 7, 8, 9]
114170
testva = DiffEqArray(x, x)

test/interface_tests.jl

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,83 @@
11
using RecursiveArrayTools, Test
22

3-
recs = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
4-
testva = VectorOfArray(recs)
3+
t = 1:3
4+
testva = VectorOfArray([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
5+
testda = DiffEqArray([[1, 2, 3], [4, 5, 6], [7, 8, 9]],t)
56

67
for (i, elem) in enumerate(testva)
78
@test elem == testva[i]
89
end
910

11+
for (i, elem) in enumerate(testda)
12+
@test elem == testda[i]
13+
end
14+
1015
push!(testva, [10, 11, 12])
1116
@test testva[:, end] == [10, 11, 12]
17+
push!(testda, [10, 11, 12])
18+
@test testda[:, end] == [10, 11, 12]
19+
1220
testva2 = copy(testva)
1321
push!(testva2, [13, 14, 15])
22+
testda2 = copy(testva)
23+
push!(testda2, [13, 14, 15])
24+
1425
# make sure we copy when we pass containers
1526
@test size(testva) == (3, 4)
1627
@test testva2[:, end] == [13, 14, 15]
28+
@test size(testda) == (3, 4)
29+
@test testda2[:, end] == [13, 14, 15]
1730

1831
append!(testva, testva)
1932
@test testva[1:2, 5:6] == [1 4; 2 5]
33+
append!(testda, testda)
34+
@test testda[1:2, 5:6] == [1 4; 2 5]
2035

2136
# Test that adding a array of different dimension makes the array ragged
2237
push!(testva, [-1, -2, -3, -4])
38+
push!(testda, [-1, -2, -3, -4])
2339
#testva #TODO: this screws up printing, try to make a fallback
2440
@test testva[1:2, 5:6] == [1 4; 2 5] # we just let the indexing happen if it works
25-
testva[4, 9] # == testva.data[9][4]
41+
@test testda[1:2, 5:6] == [1 4; 2 5]
42+
2643
@test_throws BoundsError testva[4:5, 5:6]
44+
@test_throws BoundsError testda[4:5, 5:6]
45+
2746
@test testva[9] == [-1, -2, -3, -4]
2847
@test testva[end] == [-1, -2, -3, -4]
48+
@test testda[9] == [-1, -2, -3, -4]
49+
@test testda[end] == [-1, -2, -3, -4]
2950

3051
# Currently we enforce the general shape, they can just be different lengths, ie we
3152
# can't do
3253
# Decide if this is desired, or remove this restriction
3354
@test_throws MethodError push!(testva, [-1 -2 -3 -4])
3455
@test_throws MethodError push!(testva, [-1 -2; -3 -4])
56+
@test_throws MethodError push!(testda, [-1 -2 -3 -4])
57+
@test_throws MethodError push!(testda, [-1 -2; -3 -4])
3558

36-
# convert array from VectorOfArray
59+
# convert array from VectorOfArray/DiffEqArray
60+
t = 1:8
3761
recs = [rand(10, 7) for i = 1:8]
3862
testva = VectorOfArray(recs)
63+
testda = DiffEqArray(recs,t)
3964
testa = cat(recs...,dims=3)
65+
4066
@test convert(Array,testva) == testa
67+
@test convert(Array,testda) == testa
4168

69+
t = 1:3
4270
recs = [[1 2; 3 4], [3 5; 6 7], [8 9; 10 11]]
4371
testva = VectorOfArray(recs)
72+
testda = DiffEqArray(recs,t)
73+
4474
@test size(convert(Array,testva)) == (2,2,3)
75+
@test size(convert(Array,testda)) == (2,2,3)
4576

4677
# create similar VectorOfArray
4778
recs = [rand(6) for i = 1:4]
4879
testva = VectorOfArray(recs)
80+
4981
testva2 = similar(testva)
5082
@test typeof(testva2) == typeof(testva)
5183
@test size(testva2) == size(testva)
@@ -77,3 +109,6 @@ emptyda = DiffEqArray(Array{Vector{Float64}}([]), Vector{Float64}())
77109

78110
A = VectorOfArray(map(i->rand(2,4),1:7))
79111
@test map(x->maximum(x),A) isa Vector
112+
113+
DA = DiffEqArray(map(i->rand(2,4),1:7), 1:7)
114+
@test map(x->maximum(x),DA) isa Vector

0 commit comments

Comments
 (0)