Skip to content

Commit d819cbc

Browse files
committed
Add support for arrays of StaticArrays and use Requires
Using Requires reduces package load time a bit when StaticArrays and StatsBase are not used.
1 parent 30af502 commit d819cbc

File tree

8 files changed

+57
-13
lines changed

8 files changed

+57
-13
lines changed

REQUIRE

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
julia 0.7
22
Requires
3-
StatsBase
43
UnsafeArrays

src/ArraysOfArrays.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,20 @@ __precompile__(true)
44

55
module ArraysOfArrays
66

7+
using Requires
78
using Statistics
8-
99
using UnsafeArrays
1010

11-
import StatsBase
12-
1311
include("util.jl")
1412
include("functions.jl")
1513
include("array_of_similar_arrays.jl")
1614
include("vector_of_arrays.jl")
1715

16+
17+
function __init__()
18+
@require StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" include("staticarrays_support.jl")
19+
@require StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" include("statsbase_support.jl")
20+
end
21+
22+
1823
end # module

src/array_of_similar_arrays.jl

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,6 @@ Compute the sum of the elements vectors of `X`. Equivalent to `sum` of
371371
`flatview(X)` along the last dimension.
372372
"""
373373
Base.sum(X::AbstractVectorOfSimilarArrays{T,M}) where {T,M} = sum(flatview(X); dims = M + 1)
374-
Base.sum(X::AbstractVectorOfSimilarArrays{T,M}, w::StatsBase.AbstractWeights) where {T,M} = sum(flatview(X), w, M + 1)
375374

376375

377376
"""
@@ -383,8 +382,6 @@ Compute the mean of the elements vectors of `X`. Equivalent to `mean` of
383382
"""
384383
Statistics.mean(X::AbstractVectorOfSimilarArrays{T,M}) where {T,M} =
385384
mean(flatview(X); dims = M + 1)
386-
Statistics.mean(X::AbstractVectorOfSimilarArrays{T,M}, w::StatsBase.AbstractWeights) where {T,M} =
387-
mean(flatview(X), w, M + 1)
388385

389386

390387
"""
@@ -396,8 +393,6 @@ Compute the sample variance of the elements vectors of `X`. Equivalent to
396393
"""
397394
Statistics.var(X::AbstractVectorOfSimilarArrays{T,M}; corrected::Bool = true) where {T,M} =
398395
var(flatview(X); dims = M + 1, corrected = corrected)
399-
Statistics.var(X::AbstractVectorOfSimilarArrays{T,M}, w::StatsBase.AbstractWeights; corrected::Bool = true) where {T,M} =
400-
var(flatview(X), w, M + 1; corrected = corrected)
401396

402397

403398
"""
@@ -409,8 +404,6 @@ along `X`. Equivalent to `cov` of `flatview(X)` along dimension 2.
409404
"""
410405
Statistics.cov(X::AbstractVectorOfSimilarVectors; corrected::Bool = true) =
411406
cov(flatview(X); dims = 2, corrected = corrected)
412-
Statistics.cov(X::AbstractVectorOfSimilarVectors, w::StatsBase.AbstractWeights; corrected::Bool = true) =
413-
cov(flatview(X), w, 2; corrected = corrected)
414407

415408

416409
"""
@@ -422,5 +415,3 @@ Compute the Pearson correlation matrix between the elements of the elements of
422415
"""
423416
Statistics.cor(X::AbstractVectorOfSimilarVectors) =
424417
cor(flatview(X); dims = 2)
425-
Statistics.cor(X::AbstractVectorOfSimilarVectors, w::StatsBase.AbstractWeights) =
426-
cor(flatview(X), w, 2)

src/functions.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,12 @@ AbstractArray{<:AbstractArray{T,M},N}
5757
5858
View array `A` in as an `M`-dimensional array of `N`-dimensional arrays by
5959
wrapping it into an [`ArrayOfSimilarArrays`](@ref).
60+
61+
It's also possible to use a `StaticVector` of length `S` as the type of the
62+
inner arrays via
63+
64+
nestedview(A::AbstractArray{T}, ::Type{StaticArrays.SVector{S}})
65+
nestedview(A::AbstractArray{T}, ::Type{StaticArrays.SVector{S,T}})
6066
"""
6167
function nestedview end
6268
export nestedview

src/staticarrays_support.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# This file is a part of ArraysOfArrays.jl, licensed under the MIT License (MIT).
2+
3+
4+
@inline flatview(A::AbstractArray{SA,N}) where {S,T,M,N,SA<:StaticArrays.StaticArray{S,T,M}} =
5+
reshape(reinterpret(T, A), size(SA)..., size(A)...)
6+
7+
8+
@inline function nestedview(A::AbstractArray{T}, SA::Type{StaticArrays.SVector{S,T}}) where {T,S}
9+
size_A = size(A)
10+
size_A[1] == S || throw(DimensionMismatch("Length $S of static vector type does not match first dimension of array of size $size_A"))
11+
reshape(reinterpret(SA, A), _tail(size_A)...)
12+
end
13+
14+
@inline nestedview(A::AbstractArray{T}, ::Type{StaticArrays.SVector{S}}) where {T,S} =
15+
nestedview(A, StaticArrays.SVector{S,T})

src/statsbase_support.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# This file is a part of ArraysOfArrays.jl, licensed under the MIT License (MIT).
2+
3+
4+
Base.sum(X::AbstractVectorOfSimilarArrays{T,M}, w::StatsBase.AbstractWeights) where {T,M} = sum(flatview(X), w, M + 1)
5+
6+
Statistics.mean(X::AbstractVectorOfSimilarArrays{T,M}, w::StatsBase.AbstractWeights) where {T,M} =
7+
mean(flatview(X), w, M + 1)
8+
9+
Statistics.var(X::AbstractVectorOfSimilarArrays{T,M}, w::StatsBase.AbstractWeights; corrected::Bool = true) where {T,M} =
10+
var(flatview(X), w, M + 1; corrected = corrected)
11+
12+
Statistics.cov(X::AbstractVectorOfSimilarVectors, w::StatsBase.AbstractWeights; corrected::Bool = true) =
13+
cov(flatview(X), w, 2; corrected = corrected)
14+
15+
Statistics.cor(X::AbstractVectorOfSimilarVectors, w::StatsBase.AbstractWeights) =
16+
cor(flatview(X), w, 2)

test/REQUIRE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
ElasticArrays
2+
StaticArrays

test/functions.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
using ArraysOfArrays
44
using Test
55

6+
using StaticArrays
7+
68

79
@testset "functions" begin
810
function gen_nested()
@@ -14,6 +16,15 @@ using Test
1416
end
1517

1618

19+
@testset "flatview and nestedview" begin
20+
A = [(@SArray randn(3, 2, 4)) for i in 1:2, j in 1:2]
21+
@test @inferred(nestedview(flatview(A), Val(3))) == A
22+
23+
B = rand(3, 2, 4)
24+
@test @inferred(nestedview(flatview(B), SVector{3})) == @inferred(nestedview(B, Val(1)))
25+
end
26+
27+
1728
@testset "deepgetindex" begin
1829
A = gen_nested()
1930
@test @inferred(deepgetindex(A, 1, 2)) === A[1, 2]

0 commit comments

Comments
 (0)