From 7b4398ac71482c60cde04780eac62a8df8e898ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20M=C3=BCller-Widmann?= Date: Sat, 3 May 2025 16:24:36 +0200 Subject: [PATCH] Remove specializations for `BlasReal` of `wsum` --- src/weights.jl | 92 -------------------------------------------------- 1 file changed, 92 deletions(-) diff --git a/src/weights.jl b/src/weights.jl index d76470bc..1fc0766a 100644 --- a/src/weights.jl +++ b/src/weights.jl @@ -435,84 +435,6 @@ end # that keeps a local accumulator will be used when dim = 1. # # The internal function that implements this is _wsum_general! -# -# 3. _wsum! is specialized for following cases: -# (a) A is a vector: we invoke the vector version wsum above. -# The internal function that implements this is _wsum1! -# -# (b) A is a dense matrix with eltype <: BlasReal: we call gemv! -# The internal function that implements this is _wsum2_blas! -# -# (c) A is a contiguous array with eltype <: BlasReal: -# dim == 1: treat A like a matrix of size (d1, d2 x ... x dN) -# dim == N: treat A like a matrix of size (d1 x ... x d(N-1), dN) -# otherwise: decompose A into multiple pages, and apply _wsum2! -# for each -# -# (d) A is a general dense array with eltype <: BlasReal: -# dim <= 2: delegate to (a) and (b) -# otherwise, decompose A into multiple pages - -function _wsum1!(R::AbstractArray, A::AbstractVector, w::AbstractVector, init::Bool) - r = wsum(A, w) - if init - R[1] = r - else - R[1] += r - end - return R -end - -function _wsum2_blas!(R::StridedVector{T}, A::StridedMatrix{T}, w::StridedVector{T}, dim::Int, init::Bool) where T<:BlasReal - beta = ifelse(init, zero(T), one(T)) - trans = dim == 1 ? 'T' : 'N' - BLAS.gemv!(trans, one(T), A, w, beta, R) - return R -end - -function _wsumN!(R::StridedArray{T}, A::StridedArray{T,N}, w::StridedVector{T}, dim::Int, init::Bool) where {T<:BlasReal,N} - if dim == 1 - m = size(A, 1) - n = div(length(A), m) - _wsum2_blas!(view(R,:), reshape(A, (m, n)), w, 1, init) - elseif dim == N - n = size(A, N) - m = div(length(A), n) - _wsum2_blas!(view(R,:), reshape(A, (m, n)), w, 2, init) - else # 1 < dim < N - m = 1 - for i = 1:dim-1; m *= size(A, i); end - n = size(A, dim) - k = 1 - for i = dim+1:N; k *= size(A, i); end - Av = reshape(A, (m, n, k)) - Rv = reshape(R, (m, k)) - for i = 1:k - _wsum2_blas!(view(Rv,:,i), view(Av,:,:,i), w, 2, init) - end - end - return R -end - -function _wsumN!(R::StridedArray{T}, A::DenseArray{T,N}, w::StridedVector{T}, dim::Int, init::Bool) where {T<:BlasReal,N} - @assert N >= 3 - if dim <= 2 - m = size(A, 1) - n = size(A, 2) - npages = 1 - for i = 3:N - npages *= size(A, i) - end - rlen = ifelse(dim == 1, n, m) - Rv = reshape(R, (rlen, npages)) - for i = 1:npages - _wsum2_blas!(view(Rv,:,i), view(A,:,:,i), w, dim, init) - end - else - _wsum_general!(R, identity, A, w, dim, init) - end - return R -end ## general Cartesian-based weighted sum across dimensions @@ -572,25 +494,12 @@ end end end -# N = 1 -_wsum!(R::StridedArray{T}, A::DenseArray{T,1}, w::StridedVector{T}, dim::Int, init::Bool) where {T<:BlasReal} = - _wsum1!(R, A, w, init) - -# N = 2 -_wsum!(R::StridedArray{T}, A::DenseArray{T,2}, w::StridedVector{T}, dim::Int, init::Bool) where {T<:BlasReal} = - (_wsum2_blas!(view(R,:), A, w, dim, init); R) - -# N >= 3 -_wsum!(R::StridedArray{T}, A::DenseArray{T,N}, w::StridedVector{T}, dim::Int, init::Bool) where {T<:BlasReal,N} = - _wsumN!(R, A, w, dim, init) - _wsum!(R::AbstractArray, A::AbstractArray, w::AbstractVector, dim::Int, init::Bool) = _wsum_general!(R, identity, A, w, dim, init) ## wsum! and wsum wsumtype(::Type{T}, ::Type{W}) where {T,W} = typeof(zero(T) * zero(W) + zero(T) * zero(W)) -wsumtype(::Type{T}, ::Type{T}) where {T<:BlasReal} = T """ wsum!(R::AbstractArray, A::AbstractArray, @@ -663,7 +572,6 @@ _mean!(R::AbstractArray, A::AbstractArray, w::AbstractWeights, dims::Int) = rmul!(Base.sum!(R, A, w, dims), inv(sum(w))) wmeantype(::Type{T}, ::Type{W}) where {T,W} = typeof((zero(T)*zero(W) + zero(T)*zero(W)) / one(W)) -wmeantype(::Type{T}, ::Type{T}) where {T<:BlasReal} = T """ mean(A::AbstractArray, w::AbstractWeights[, dims::Int])