Skip to content

Commit 1ecc966

Browse files
fix: VectorOfArray mapreduce, sum/prod performance, minor bugs
1 parent 873a072 commit 1ecc966

File tree

2 files changed

+73
-9
lines changed

2 files changed

+73
-9
lines changed

src/vector_of_array.jl

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,16 @@ function VectorOfArray(vec::AbstractVector{T}, ::NTuple{N}) where {T, N}
133133
VectorOfArray{eltype(T), N, typeof(vec)}(vec)
134134
end
135135
# Assume that the first element is representative of all other elements
136-
VectorOfArray(vec::AbstractVector) = VectorOfArray(vec, (size(vec[1])..., length(vec)))
136+
function VectorOfArray(vec::AbstractVector)
137+
T = eltype(vec[1])
138+
N = ndims(vec[1])
139+
if all(x isa Union{<:AbstractArray, <:AbstractVectorOfArray} for x in vec)
140+
A = Vector{Union{typeof.(vec)...}}
141+
else
142+
A = typeof(vec)
143+
end
144+
VectorOfArray{T, N + 1, A}(vec)
145+
end
137146
function VectorOfArray(vec::AbstractVector{VT}) where {T, N, VT <: AbstractArray{T, N}}
138147
VectorOfArray{T, N + 1, typeof(vec)}(vec)
139148
end
@@ -482,21 +491,30 @@ function Base.append!(VA::AbstractVectorOfArray{T, N},
482491
return VA
483492
end
484493

494+
function Base.stack(VA::AbstractVectorOfArray; dims = :)
495+
stack(VA.u; dims)
496+
end
497+
485498
# AbstractArray methods
486499
function Base.view(A::AbstractVectorOfArray, I::Vararg{Any,M}) where {M}
487500
@inline
488501
J = map(i->Base.unalias(A,i), to_indices(A, I))
489502
@boundscheck checkbounds(A, J...)
490503
SubArray(IndexStyle(A), A, J, Base.index_dimsum(J...))
491504
end
505+
function Base.SubArray(parent::AbstractVectorOfArray, indices::Tuple)
506+
@inline
507+
SubArray(IndexStyle(Base.viewindexing(indices), IndexStyle(parent)), parent, Base.ensure_indexable(indices), Base.index_dimsum(indices...))
508+
end
509+
Base.isassigned(VA::AbstractVectorOfArray, idxs...) = checkbounds(Bool, VA, idxs...)
492510
Base.check_parent_index_match(::RecursiveArrayTools.AbstractVectorOfArray{T,N}, ::NTuple{N,Bool}) where {T,N} = nothing
493511
Base.ndims(::AbstractVectorOfArray{T, N}) where {T, N} = N
494512
function Base.checkbounds(::Type{Bool}, VA::AbstractVectorOfArray, idx...)
495513
if checkbounds(Bool, VA.u, last(idx))
496514
if last(idx) isa Integer
497-
return all(checkbounds.(Bool, (VA.u[last(idx)],), Base.front(idx)))
515+
return all(checkbounds.(Bool, (VA.u[last(idx)],), Base.front(idx)...))
498516
else
499-
return all(checkbounds.(Bool, VA.u[last(idx)], Base.front(idx)))
517+
return all(checkbounds.(Bool, VA.u[last(idx)], tuple.(Base.front(idx))...))
500518
end
501519
end
502520
return false
@@ -595,10 +613,14 @@ function Base.convert(::Type{Array}, VA::AbstractVectorOfArray)
595613
end
596614

597615
# statistics
598-
@inline Base.sum(f, VA::AbstractVectorOfArray) = sum(f, Array(VA))
599-
@inline Base.sum(VA::AbstractVectorOfArray; kwargs...) = sum(Array(VA); kwargs...)
600-
@inline Base.prod(f, VA::AbstractVectorOfArray) = prod(f, Array(VA))
601-
@inline Base.prod(VA::AbstractVectorOfArray; kwargs...) = prod(Array(VA); kwargs...)
616+
@inline Base.sum(VA::AbstractVectorOfArray; kwargs...) = sum(identity, VA; kwargs...)
617+
@inline function Base.sum(f, VA::AbstractVectorOfArray; kwargs...)
618+
mapreduce(f, Base.add_sum, VA; kwargs...)
619+
end
620+
@inline Base.prod(VA::AbstractVectorOfArray; kwargs...) = prod(identity, VA; kwargs...)
621+
@inline function Base.prod(f, VA::AbstractVectorOfArray; kwargs...)
622+
mapreduce(f, Base.mul_prod, VA; kwargs...)
623+
end
602624

603625
@inline Statistics.mean(VA::AbstractVectorOfArray; kwargs...) = mean(Array(VA); kwargs...)
604626
@inline function Statistics.median(VA::AbstractVectorOfArray; kwargs...)
@@ -638,8 +660,12 @@ end
638660
end
639661

640662
Base.map(f, A::RecursiveArrayTools.AbstractVectorOfArray) = map(f, A.u)
641-
function Base.mapreduce(f, op, A::AbstractVectorOfArray)
642-
mapreduce(f, op, (mapreduce(f, op, x) for x in A.u))
663+
664+
function Base.mapreduce(f, op, A::AbstractVectorOfArray; kwargs...)
665+
mapreduce(f, op, view(A, ntuple(_ -> :, ndims(A))...); kwargs...)
666+
end
667+
function Base.mapreduce(f, op, A::AbstractVectorOfArray{T,1,<:AbstractVector{T}}; kwargs...) where {T}
668+
mapreduce(f, op, A.u; kwargs...)
643669
end
644670

645671
## broadcasting

test/interface_tests.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,44 @@ push!(testda, [-1, -2, -3, -4])
5757
@test_throws MethodError push!(testda, [-1 -2 -3 -4])
5858
@test_throws MethodError push!(testda, [-1 -2; -3 -4])
5959

60+
# Type inference
61+
@inferred sum(testva)
62+
@inferred sum(VectorOfArray([VectorOfArray([zeros(4,4)])]))
63+
@inferred mapreduce(string, *, testva)
64+
65+
# mapreduce
66+
testva = VectorOfArray([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
67+
@test mapreduce(x -> string(x) * "q", *, testva) == "1q2q3q4q5q6q7q8q9q"
68+
69+
testvb = VectorOfArray([rand(1:10, 3, 3, 3) for _ in 1:4])
70+
arrvb = Array(testvb)
71+
for i in 1:ndims(arrvb)
72+
@test sum(arrvb; dims=i) == sum(testvb; dims=i)
73+
@test prod(arrvb; dims=i) == prod(testvb; dims=i)
74+
@test mapreduce(string, *, arrvb; dims=i) == mapreduce(string, *, testvb; dims=i)
75+
end
76+
77+
# Test when ndims == 1
78+
testvb = VectorOfArray(collect(1.0:0.1:2.0))
79+
arrvb = Array(testvb)
80+
@test sum(arrvb) == sum(testvb)
81+
@test prod(arrvb) == prod(testvb)
82+
@test mapreduce(string, *, arrvb) == mapreduce(string, *, testvb)
83+
84+
# view
85+
testvc = VectorOfArray([rand(1:10, 3, 3) for _ in 1:3])
86+
arrvc = Array(testvc)
87+
for idxs in [(2, 2, :), (2, :, 2), (:, 2, 2), (:, :, 2), (:, 2, :), (2, : ,:), (:, :, :)]
88+
arr_view = view(arrvc, idxs...)
89+
voa_view = view(testvc, idxs...)
90+
@test size(arr_view) == size(voa_view)
91+
@test all(arr_view .== voa_view)
92+
end
93+
94+
# test stack
95+
@test stack(testva) == [1 4 7; 2 5 8; 3 6 9]
96+
@test stack(testva; dims = 1) == [1 2 3; 4 5 6; 7 8 9]
97+
6098
# convert array from VectorOfArray/DiffEqArray
6199
t = 1:8
62100
recs = [rand(10, 7) for i in 1:8]

0 commit comments

Comments
 (0)