Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 0 additions & 92 deletions src/weights.jl
Original file line number Diff line number Diff line change
Expand Up @@ -435,84 +435,6 @@ end
# that keeps a local accumulator will be used when dim = 1.
#
# The internal function that implements this is _wsum_general!
#
# 3. _wsum! is specialized for following cases:
# (a) A is a vector: we invoke the vector version wsum above.
# The internal function that implements this is _wsum1!
#
# (b) A is a dense matrix with eltype <: BlasReal: we call gemv!
# The internal function that implements this is _wsum2_blas!
#
# (c) A is a contiguous array with eltype <: BlasReal:
# dim == 1: treat A like a matrix of size (d1, d2 x ... x dN)
# dim == N: treat A like a matrix of size (d1 x ... x d(N-1), dN)
# otherwise: decompose A into multiple pages, and apply _wsum2!
# for each
#
# (d) A is a general dense array with eltype <: BlasReal:
# dim <= 2: delegate to (a) and (b)
# otherwise, decompose A into multiple pages

function _wsum1!(R::AbstractArray, A::AbstractVector, w::AbstractVector, init::Bool)
r = wsum(A, w)
if init
R[1] = r
else
R[1] += r
end
return R
end

function _wsum2_blas!(R::StridedVector{T}, A::StridedMatrix{T}, w::StridedVector{T}, dim::Int, init::Bool) where T<:BlasReal
beta = ifelse(init, zero(T), one(T))
trans = dim == 1 ? 'T' : 'N'
BLAS.gemv!(trans, one(T), A, w, beta, R)
return R
end

function _wsumN!(R::StridedArray{T}, A::StridedArray{T,N}, w::StridedVector{T}, dim::Int, init::Bool) where {T<:BlasReal,N}
if dim == 1
m = size(A, 1)
n = div(length(A), m)
_wsum2_blas!(view(R,:), reshape(A, (m, n)), w, 1, init)
elseif dim == N
n = size(A, N)
m = div(length(A), n)
_wsum2_blas!(view(R,:), reshape(A, (m, n)), w, 2, init)
else # 1 < dim < N
m = 1
for i = 1:dim-1; m *= size(A, i); end
n = size(A, dim)
k = 1
for i = dim+1:N; k *= size(A, i); end
Av = reshape(A, (m, n, k))
Rv = reshape(R, (m, k))
for i = 1:k
_wsum2_blas!(view(Rv,:,i), view(Av,:,:,i), w, 2, init)
end
end
return R
end

function _wsumN!(R::StridedArray{T}, A::DenseArray{T,N}, w::StridedVector{T}, dim::Int, init::Bool) where {T<:BlasReal,N}
@assert N >= 3
if dim <= 2
m = size(A, 1)
n = size(A, 2)
npages = 1
for i = 3:N
npages *= size(A, i)
end
rlen = ifelse(dim == 1, n, m)
Rv = reshape(R, (rlen, npages))
for i = 1:npages
_wsum2_blas!(view(Rv,:,i), view(A,:,:,i), w, dim, init)
end
else
_wsum_general!(R, identity, A, w, dim, init)
end
return R
end

## general Cartesian-based weighted sum across dimensions

Expand Down Expand Up @@ -572,25 +494,12 @@ end
end
end

# N = 1
_wsum!(R::StridedArray{T}, A::DenseArray{T,1}, w::StridedVector{T}, dim::Int, init::Bool) where {T<:BlasReal} =
_wsum1!(R, A, w, init)

# N = 2
_wsum!(R::StridedArray{T}, A::DenseArray{T,2}, w::StridedVector{T}, dim::Int, init::Bool) where {T<:BlasReal} =
(_wsum2_blas!(view(R,:), A, w, dim, init); R)

# N >= 3
_wsum!(R::StridedArray{T}, A::DenseArray{T,N}, w::StridedVector{T}, dim::Int, init::Bool) where {T<:BlasReal,N} =
_wsumN!(R, A, w, dim, init)

_wsum!(R::AbstractArray, A::AbstractArray, w::AbstractVector, dim::Int, init::Bool) =
_wsum_general!(R, identity, A, w, dim, init)

## wsum! and wsum

wsumtype(::Type{T}, ::Type{W}) where {T,W} = typeof(zero(T) * zero(W) + zero(T) * zero(W))
wsumtype(::Type{T}, ::Type{T}) where {T<:BlasReal} = T

"""
wsum!(R::AbstractArray, A::AbstractArray,
Expand Down Expand Up @@ -663,7 +572,6 @@ _mean!(R::AbstractArray, A::AbstractArray, w::AbstractWeights, dims::Int) =
rmul!(Base.sum!(R, A, w, dims), inv(sum(w)))

wmeantype(::Type{T}, ::Type{W}) where {T,W} = typeof((zero(T)*zero(W) + zero(T)*zero(W)) / one(W))
wmeantype(::Type{T}, ::Type{T}) where {T<:BlasReal} = T

"""
mean(A::AbstractArray, w::AbstractWeights[, dims::Int])
Expand Down
Loading