From 781c83a34b484c3d8633c16401874de4980fa259 Mon Sep 17 00:00:00 2001 From: Neven Sajko Date: Fri, 12 Sep 2025 17:01:15 +0200 Subject: [PATCH] force specialization of dims argument type where it may be `Colon` As discussed in the Performance Tips in the Manual, Julia avoids specializing calls on the type of `Function` arguments by default. The function `:`, `(:) isa Function`, is often used not as a callable, but as the special value for specifying the "full range" or "all dimensions". However, in such cases we often forget to force specialization on the type of `:`. This change fixes that. See PR JuliaLang/julia#59474, which applies the same kind of change in Julia itself. NB: I don't have an example where this change helps for StaticArrays; however, the eliminated invalidation in the linked JuliaLang/julia PR is proof that it does help in some cases. Finding such examples is difficult because the compiler is often able to achieve good results through constant propagation. However, constprop is often fragile, so it is better to avoid relying on it. For example, constprop through recursion is not even attempted by the Julia compiler. I believe this change should not cause any real-world compile-time regression, as `:` is the only function that is valid as a dims argument. 
--- ext/StaticArraysStatisticsExt.jl | 6 ++-- src/mapreduce.jl | 50 ++++++++++++++++---------------- 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/ext/StaticArraysStatisticsExt.jl b/ext/StaticArraysStatisticsExt.jl index 6a665617..526a1ded 100644 --- a/ext/StaticArraysStatisticsExt.jl +++ b/ext/StaticArraysStatisticsExt.jl @@ -12,10 +12,10 @@ _mean_denom(a, ::Colon) = length(a) _mean_denom(a, dims::Int) = size(a, dims) _mean_denom(a, ::Val{D}) where {D} = size(a, D) -@inline mean(a::StaticArray; dims=:) = _reduce(+, a, dims) / _mean_denom(a, dims) -@inline mean(f::Function, a::StaticArray; dims=:) = _mapreduce(f, +, dims, _InitialValue(), Size(a), a) / _mean_denom(a, dims) +@inline mean(a::StaticArray; dims::D=:) where {D} = _reduce(+, a, dims) / _mean_denom(a, dims) +@inline mean(f::Function, a::StaticArray; dims::D=:) where {D} = _mapreduce(f, +, dims, _InitialValue(), Size(a), a) / _mean_denom(a, dims) -@inline function median(a::StaticArray; dims = :) +@inline function median(a::StaticArray; dims::D = :) where {D} if dims == Colon() median(vec(a)) else diff --git a/src/mapreduce.jl b/src/mapreduce.jl index 1e0e6e1b..899ec2ab 100644 --- a/src/mapreduce.jl +++ b/src/mapreduce.jl @@ -146,7 +146,7 @@ end ## mapreduce ## ############### -@inline function mapreduce(f, op, a::StaticArray, b::StaticArray...; dims=:, init = _InitialValue()) +@inline function mapreduce(f, op, a::StaticArray, b::StaticArray...; dims::D=:, init = _InitialValue()) where {D} _mapreduce(f, op, dims, init, same_size(a, b...), a, b...) 
end @@ -235,7 +235,7 @@ end ## reduce ## ############ -@inline reduce(op::R, a::StaticArray; dims = :, init = _InitialValue()) where {R} = +@inline reduce(op::R, a::StaticArray; dims::D = :, init = _InitialValue()) where {D, R} = _reduce(op, a, dims, init) # disambiguation @@ -249,7 +249,7 @@ reduce(::typeof(hcat), A::StaticArray{<:Tuple,<:AbstractVecOrMat}) = reduce(::typeof(hcat), A::StaticArray{<:Tuple,<:StaticVecOrMatLike}) = _reduce(hcat, A, :, _InitialValue()) -@inline _reduce(op::R, a::StaticArray, dims, init = _InitialValue()) where {R} = +@inline _reduce(op::R, a::StaticArray, dims::D, init = _InitialValue()) where {D, R} = _mapreduce(identity, op, dims, init, Size(a), a) ################ @@ -264,7 +264,7 @@ reduce(::typeof(hcat), A::StaticArray{<:Tuple,<:StaticVecOrMatLike}) = _mapfoldl(f, op, :, init, Size(a), a) @inline foldl(op::R, a::StaticArray; init = _InitialValue()) where {R} = _foldl(op, a, :, init) -@inline _foldl(op::R, a, dims, init = _InitialValue()) where {R} = +@inline _foldl(op::R, a, dims::D, init = _InitialValue()) where {D, R} = _mapfoldl(identity, op, dims, init, Size(a), a) ####################### @@ -290,33 +290,33 @@ reduce(::typeof(hcat), A::StaticArray{<:Tuple,<:StaticVecOrMatLike}) = # TODO: change to use Base.reduce_empty/Base.reduce_first @inline iszero(a::StaticArray{<:Tuple,T}) where {T} = reduce((x,y) -> x && iszero(y), a, init=true) -@inline sum(a::StaticArray{<:Tuple,T}; dims=:, init=_InitialValue()) where {T} = _reduce(+, a, dims, init) -@inline sum(f, a::StaticArray{<:Tuple,T}; dims=:, init=_InitialValue()) where {T} = _mapreduce(f, +, dims, init, Size(a), a) -@inline sum(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:, init=_InitialValue()) where {T} = _mapreduce(f, +, dims, init, Size(a), a) # avoid ambiguity +@inline sum(a::StaticArray{<:Tuple,T}; dims::D=:, init=_InitialValue()) where {D, T} = _reduce(+, a, dims, init) +@inline sum(f, a::StaticArray{<:Tuple,T}; dims::D=:, init=_InitialValue()) where {D, 
T} = _mapreduce(f, +, dims, init, Size(a), a) +@inline sum(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims::D=:, init=_InitialValue()) where {D, T} = _mapreduce(f, +, dims, init, Size(a), a) # avoid ambiguity -@inline prod(a::StaticArray{<:Tuple,T}; dims=:, init=_InitialValue()) where {T} = _reduce(*, a, dims, init) -@inline prod(f, a::StaticArray{<:Tuple,T}; dims=:, init=_InitialValue()) where {T} = _mapreduce(f, *, dims, init, Size(a), a) -@inline prod(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:, init=_InitialValue()) where {T} = _mapreduce(f, *, dims, init, Size(a), a) +@inline prod(a::StaticArray{<:Tuple,T}; dims::D=:, init=_InitialValue()) where {D, T} = _reduce(*, a, dims, init) +@inline prod(f, a::StaticArray{<:Tuple,T}; dims::D=:, init=_InitialValue()) where {D, T} = _mapreduce(f, *, dims, init, Size(a), a) +@inline prod(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims::D=:, init=_InitialValue()) where {D, T} = _mapreduce(f, *, dims, init, Size(a), a) -@inline count(a::StaticArray{<:Tuple,Bool}; dims=:, init=0) = _reduce(+, a, dims, init) -@inline count(f, a::StaticArray; dims=:, init=0) = _mapreduce(x->f(x)::Bool, +, dims, init, Size(a), a) +@inline count(a::StaticArray{<:Tuple,Bool}; dims::D=:, init=0) where {D} = _reduce(+, a, dims, init) +@inline count(f, a::StaticArray; dims::D=:, init=0) where {D} = _mapreduce(x->f(x)::Bool, +, dims, init, Size(a), a) -@inline all(a::StaticArray{<:Tuple,Bool}; dims=:) = _reduce(&, a, dims, true) # non-branching versions -@inline all(f::Function, a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, &, dims, true, Size(a), a) +@inline all(a::StaticArray{<:Tuple,Bool}; dims::D=:) where {D} = _reduce(&, a, dims, true) # non-branching versions +@inline all(f::Function, a::StaticArray; dims::D=:) where {D} = _mapreduce(x->f(x)::Bool, &, dims, true, Size(a), a) -@inline any(a::StaticArray{<:Tuple,Bool}; dims=:) = _reduce(|, a, dims, false) # (benchmarking needed) -@inline any(f::Function, 
a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, |, dims, false, Size(a), a) # (benchmarking needed) +@inline any(a::StaticArray{<:Tuple,Bool}; dims::D=:) where {D} = _reduce(|, a, dims, false) # (benchmarking needed) +@inline any(f::Function, a::StaticArray; dims::D=:) where {D} = _mapreduce(x->f(x)::Bool, |, dims, false, Size(a), a) # (benchmarking needed) @inline Base.in(x, a::StaticArray) = _mapreduce(==(x), |, :, false, Size(a), a) -@inline minimum(a::StaticArray; dims=:) = _reduce(min, a, dims) # base has mapreduce(identity, scalarmin, a) -@inline minimum(f::Function, a::StaticArray; dims=:) = _mapreduce(f, min, dims, _InitialValue(), Size(a), a) +@inline minimum(a::StaticArray; dims::D=:) where {D} = _reduce(min, a, dims) # base has mapreduce(identity, scalarmin, a) +@inline minimum(f::Function, a::StaticArray; dims::D=:) where {D} = _mapreduce(f, min, dims, _InitialValue(), Size(a), a) -@inline maximum(a::StaticArray; dims=:) = _reduce(max, a, dims) # base has mapreduce(identity, scalarmax, a) -@inline maximum(f::Function, a::StaticArray; dims=:) = _mapreduce(f, max, dims, _InitialValue(), Size(a), a) +@inline maximum(a::StaticArray; dims::D=:) where {D} = _reduce(max, a, dims) # base has mapreduce(identity, scalarmax, a) +@inline maximum(f::Function, a::StaticArray; dims::D=:) where {D} = _mapreduce(f, max, dims, _InitialValue(), Size(a), a) # Diff is slightly different -@inline diff(a::StaticArray; dims) = _diff(Size(a), a, dims) +@inline diff(a::StaticArray; dims::D) where {D} = _diff(Size(a), a, dims) @inline diff(a::StaticVector) = diff(a;dims=Val(1)) @inline function _diff(sz::Size{S}, a::StaticArray, D::Int) where {S} @@ -343,16 +343,16 @@ end end _maybe_val(dims::Integer) = Val(Int(dims)) -_maybe_val(dims) = dims +_maybe_val(dims::D) where {D} = dims _valof(::Val{D}) where D = D -@inline Base.accumulate(op::F, a::StaticVector; dims = :, init = _InitialValue()) where {F} = +@inline Base.accumulate(op::F, a::StaticVector; dims::D = :, init = 
_InitialValue()) where {D, F} = _accumulate(op, a, _maybe_val(dims), init) -@inline Base.accumulate(op::F, a::StaticArray; dims, init = _InitialValue()) where {F} = +@inline Base.accumulate(op::F, a::StaticArray; dims::D, init = _InitialValue()) where {D, F} = _accumulate(op, a, _maybe_val(dims), init) -@inline function _accumulate(op::F, a::StaticArray, dims::Union{Val,Colon}, init) where {F} +@inline function _accumulate(op::F, a::StaticArray, dims::Dimensions, init) where {Dimensions <: Union{Val,Colon}, F} # Adjoin the initial value to `op` (one-line version of `Base.BottomRF`): rf(x, y) = x isa _InitialValue ? Base.reduce_first(op, y) : op(x, y)