diff --git a/ext/StaticArraysStatisticsExt.jl b/ext/StaticArraysStatisticsExt.jl index 6a665617..526a1ded 100644 --- a/ext/StaticArraysStatisticsExt.jl +++ b/ext/StaticArraysStatisticsExt.jl @@ -12,10 +12,10 @@ _mean_denom(a, ::Colon) = length(a) _mean_denom(a, dims::Int) = size(a, dims) _mean_denom(a, ::Val{D}) where {D} = size(a, D) -@inline mean(a::StaticArray; dims=:) = _reduce(+, a, dims) / _mean_denom(a, dims) -@inline mean(f::Function, a::StaticArray; dims=:) = _mapreduce(f, +, dims, _InitialValue(), Size(a), a) / _mean_denom(a, dims) +@inline mean(a::StaticArray; dims::D=:) where {D} = _reduce(+, a, dims) / _mean_denom(a, dims) +@inline mean(f::Function, a::StaticArray; dims::D=:) where {D} = _mapreduce(f, +, dims, _InitialValue(), Size(a), a) / _mean_denom(a, dims) -@inline function median(a::StaticArray; dims = :) +@inline function median(a::StaticArray; dims::D = :) where {D} if dims == Colon() median(vec(a)) else diff --git a/src/mapreduce.jl b/src/mapreduce.jl index 1e0e6e1b..899ec2ab 100644 --- a/src/mapreduce.jl +++ b/src/mapreduce.jl @@ -146,7 +146,7 @@ end ## mapreduce ## ############### -@inline function mapreduce(f, op, a::StaticArray, b::StaticArray...; dims=:, init = _InitialValue()) +@inline function mapreduce(f, op, a::StaticArray, b::StaticArray...; dims::D=:, init = _InitialValue()) where {D} _mapreduce(f, op, dims, init, same_size(a, b...), a, b...) end @@ -235,7 +235,7 @@ end ## reduce ## ############ -@inline reduce(op::R, a::StaticArray; dims = :, init = _InitialValue()) where {R} = +@inline reduce(op::R, a::StaticArray; dims::D = :, init = _InitialValue()) where {D, R} = _reduce(op, a, dims, init) # disambiguation @@ -249,7 +249,7 @@ reduce(::typeof(hcat), A::StaticArray{<:Tuple,<:AbstractVecOrMat}) = reduce(::typeof(hcat), A::StaticArray{<:Tuple,<:StaticVecOrMatLike}) = _reduce(hcat, A, :, _InitialValue()) -@inline _reduce(op::R, a::StaticArray, dims, init = _InitialValue()) where {R} = +@inline _reduce(op::R, a::StaticArray, dims::D, init = _InitialValue()) where {D, R} = _mapreduce(identity, op, dims, init, Size(a), a) ################ @@ -264,7 +264,7 @@ reduce(::typeof(hcat), A::StaticArray{<:Tuple,<:StaticVecOrMatLike}) = _mapfoldl(f, op, :, init, Size(a), a) @inline foldl(op::R, a::StaticArray; init = _InitialValue()) where {R} = _foldl(op, a, :, init) -@inline _foldl(op::R, a, dims, init = _InitialValue()) where {R} = +@inline _foldl(op::R, a, dims::D, init = _InitialValue()) where {D, R} = _mapfoldl(identity, op, dims, init, Size(a), a) ####################### @@ -290,33 +290,33 @@ reduce(::typeof(hcat), A::StaticArray{<:Tuple,<:StaticVecOrMatLike}) = # TODO: change to use Base.reduce_empty/Base.reduce_first @inline iszero(a::StaticArray{<:Tuple,T}) where {T} = reduce((x,y) -> x && iszero(y), a, init=true) -@inline sum(a::StaticArray{<:Tuple,T}; dims=:, init=_InitialValue()) where {T} = _reduce(+, a, dims, init) -@inline sum(f, a::StaticArray{<:Tuple,T}; dims=:, init=_InitialValue()) where {T} = _mapreduce(f, +, dims, init, Size(a), a) -@inline sum(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:, init=_InitialValue()) where {T} = _mapreduce(f, +, dims, init, Size(a), a) # avoid ambiguity +@inline sum(a::StaticArray{<:Tuple,T}; dims::D=:, init=_InitialValue()) where {D, T} = _reduce(+, a, dims, init) +@inline sum(f, a::StaticArray{<:Tuple,T}; dims::D=:, init=_InitialValue()) where {D, T} = _mapreduce(f, +, dims, init, Size(a), a) +@inline sum(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims::D=:, init=_InitialValue()) where {D, T} = _mapreduce(f, +, dims, init, Size(a), a) # avoid ambiguity -@inline prod(a::StaticArray{<:Tuple,T}; dims=:, init=_InitialValue()) where {T} = _reduce(*, a, dims, init) -@inline prod(f, a::StaticArray{<:Tuple,T}; dims=:, init=_InitialValue()) where {T} = _mapreduce(f, *, dims, init, Size(a), a) -@inline prod(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:, init=_InitialValue()) where {T} = _mapreduce(f, *, dims, init, Size(a), a) +@inline prod(a::StaticArray{<:Tuple,T}; dims::D=:, init=_InitialValue()) where {D, T} = _reduce(*, a, dims, init) +@inline prod(f, a::StaticArray{<:Tuple,T}; dims::D=:, init=_InitialValue()) where {D, T} = _mapreduce(f, *, dims, init, Size(a), a) +@inline prod(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims::D=:, init=_InitialValue()) where {D, T} = _mapreduce(f, *, dims, init, Size(a), a) -@inline count(a::StaticArray{<:Tuple,Bool}; dims=:, init=0) = _reduce(+, a, dims, init) -@inline count(f, a::StaticArray; dims=:, init=0) = _mapreduce(x->f(x)::Bool, +, dims, init, Size(a), a) +@inline count(a::StaticArray{<:Tuple,Bool}; dims::D=:, init=0) where {D} = _reduce(+, a, dims, init) +@inline count(f, a::StaticArray; dims::D=:, init=0) where {D} = _mapreduce(x->f(x)::Bool, +, dims, init, Size(a), a) -@inline all(a::StaticArray{<:Tuple,Bool}; dims=:) = _reduce(&, a, dims, true) # non-branching versions -@inline all(f::Function, a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, &, dims, true, Size(a), a) +@inline all(a::StaticArray{<:Tuple,Bool}; dims::D=:) where {D} = _reduce(&, a, dims, true) # non-branching versions +@inline all(f::Function, a::StaticArray; dims::D=:) where {D} = _mapreduce(x->f(x)::Bool, &, dims, true, Size(a), a) -@inline any(a::StaticArray{<:Tuple,Bool}; dims=:) = _reduce(|, a, dims, false) # (benchmarking needed) -@inline any(f::Function, a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, |, dims, false, Size(a), a) # (benchmarking needed) +@inline any(a::StaticArray{<:Tuple,Bool}; dims::D=:) where {D} = _reduce(|, a, dims, false) # (benchmarking needed) +@inline any(f::Function, a::StaticArray; dims::D=:) where {D} = _mapreduce(x->f(x)::Bool, |, dims, false, Size(a), a) # (benchmarking needed) @inline Base.in(x, a::StaticArray) = _mapreduce(==(x), |, :, false, Size(a), a) -@inline minimum(a::StaticArray; dims=:) = _reduce(min, a, dims) # base has mapreduce(identity, scalarmin, a) -@inline minimum(f::Function, a::StaticArray; dims=:) = _mapreduce(f, min, dims, _InitialValue(), Size(a), a) +@inline minimum(a::StaticArray; dims::D=:) where {D} = _reduce(min, a, dims) # base has mapreduce(identity, scalarmin, a) +@inline minimum(f::Function, a::StaticArray; dims::D=:) where {D} = _mapreduce(f, min, dims, _InitialValue(), Size(a), a) -@inline maximum(a::StaticArray; dims=:) = _reduce(max, a, dims) # base has mapreduce(identity, scalarmax, a) -@inline maximum(f::Function, a::StaticArray; dims=:) = _mapreduce(f, max, dims, _InitialValue(), Size(a), a) +@inline maximum(a::StaticArray; dims::D=:) where {D} = _reduce(max, a, dims) # base has mapreduce(identity, scalarmax, a) +@inline maximum(f::Function, a::StaticArray; dims::D=:) where {D} = _mapreduce(f, max, dims, _InitialValue(), Size(a), a) # Diff is slightly different -@inline diff(a::StaticArray; dims) = _diff(Size(a), a, dims) +@inline diff(a::StaticArray; dims::D) where {D} = _diff(Size(a), a, dims) @inline diff(a::StaticVector) = diff(a;dims=Val(1)) @inline function _diff(sz::Size{S}, a::StaticArray, D::Int) where {S} @@ -343,16 +343,16 @@ end end _maybe_val(dims::Integer) = Val(Int(dims)) -_maybe_val(dims) = dims +_maybe_val(dims::D) where {D} = dims _valof(::Val{D}) where D = D -@inline Base.accumulate(op::F, a::StaticVector; dims = :, init = _InitialValue()) where {F} = +@inline Base.accumulate(op::F, a::StaticVector; dims::D = :, init = _InitialValue()) where {D, F} = _accumulate(op, a, _maybe_val(dims), init) -@inline Base.accumulate(op::F, a::StaticArray; dims, init = _InitialValue()) where {F} = +@inline Base.accumulate(op::F, a::StaticArray; dims::D, init = _InitialValue()) where {D, F} = _accumulate(op, a, _maybe_val(dims), init) -@inline function _accumulate(op::F, a::StaticArray, dims::Union{Val,Colon}, init) where {F} +@inline function _accumulate(op::F, a::StaticArray, dims::Dimensions, init) where {Dimensions <: Union{Val,Colon}, F} # Adjoin the initial value to `op` (one-line version of `Base.BottomRF`): rf(x, y) = x isa _InitialValue ? Base.reduce_first(op, y) : op(x, y)