Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,12 @@ end
end
end

@inline function _mapreduce(f, op, D::Tuple, init, sz::Size{S}, a::StaticArray) where {S}
b = _mapreduce(f, op, first(D), init, sz, a)
return _mapreduce(f, op, Base.tail(D), init, Size(b), b)
end
_mapreduce(f, op, D::Tuple{}, init, sz::Size{S}, a::StaticArray) where {S} = a

@generated function _mapfoldl(f, op, dims::Val{D}, init,
::Size{S}, a::StaticArray) where {S,D}
N = length(S)
Expand Down
3 changes: 3 additions & 0 deletions test/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ using Statistics: mean
RSArray1 = SArray{Tuple{1,J,K}} # reduced in dimension 1
RSArray2 = SArray{Tuple{I,1,K}} # reduced in dimension 2
RSArray3 = SArray{Tuple{I,J,1}} # reduced in dimension 3
RSArray13 = SArray{Tuple{1,J,1}} # reduced in dimension 1 and 3
a = randn(I,J,K); sa = OSArray(a)
b = rand(Bool,I,J,K); sb = OSArray(b)
z = zeros(I,J,K); sz = OSArray(z)
Expand All @@ -111,9 +112,11 @@ using Statistics: mean
@test sum(sa) === sum(a)
@test sum(abs2, sa) === sum(abs2, a)
@test sum(sa, dims=2) === RSArray2(sum(a, dims=2))
@test sum(sa, dims=(2,)) === RSArray2(sum(a, dims=2))
@test sum(sa, dims=Val(2)) === RSArray2(sum(a, dims=2))
@test sum(abs2, sa; dims=2) === RSArray2(sum(abs2, a, dims=2))
@test sum(abs2, sa; dims=Val(2)) === RSArray2(sum(abs2, a, dims=2))
@test_broken sum(abs2, sa; dims=(1,3)) === RSArray13(sum(abs2, a, dims=(1,3)))

@test prod(sa) === prod(a)
@test prod(abs2, sa) === prod(abs2, a)
Expand Down