Skip to content

Commit 6ef1840

Browse files
authored
Remove specializations for BlasReal of wsum (#956)
1 parent aef7422 commit 6ef1840

File tree

1 file changed

+0
-92
lines changed

1 file changed

+0
-92
lines changed

src/weights.jl

Lines changed: 0 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -435,84 +435,6 @@ end
435435
# that keeps a local accumulator will be used when dim = 1.
436436
#
437437
# The internal function that implements this is _wsum_general!
438-
#
439-
# 3. _wsum! is specialized for following cases:
440-
# (a) A is a vector: we invoke the vector version wsum above.
441-
# The internal function that implements this is _wsum1!
442-
#
443-
# (b) A is a dense matrix with eltype <: BlasReal: we call gemv!
444-
# The internal function that implements this is _wsum2_blas!
445-
#
446-
# (c) A is a contiguous array with eltype <: BlasReal:
447-
# dim == 1: treat A like a matrix of size (d1, d2 x ... x dN)
448-
# dim == N: treat A like a matrix of size (d1 x ... x d(N-1), dN)
449-
# otherwise: decompose A into multiple pages, and apply _wsum2!
450-
# for each
451-
#
452-
# (d) A is a general dense array with eltype <: BlasReal:
453-
# dim <= 2: delegate to (a) and (b)
454-
# otherwise, decompose A into multiple pages
455-
456-
function _wsum1!(R::AbstractArray, A::AbstractVector, w::AbstractVector, init::Bool)
457-
r = wsum(A, w)
458-
if init
459-
R[1] = r
460-
else
461-
R[1] += r
462-
end
463-
return R
464-
end
465-
466-
function _wsum2_blas!(R::StridedVector{T}, A::StridedMatrix{T}, w::StridedVector{T}, dim::Int, init::Bool) where T<:BlasReal
467-
beta = ifelse(init, zero(T), one(T))
468-
trans = dim == 1 ? 'T' : 'N'
469-
BLAS.gemv!(trans, one(T), A, w, beta, R)
470-
return R
471-
end
472-
473-
function _wsumN!(R::StridedArray{T}, A::StridedArray{T,N}, w::StridedVector{T}, dim::Int, init::Bool) where {T<:BlasReal,N}
474-
if dim == 1
475-
m = size(A, 1)
476-
n = div(length(A), m)
477-
_wsum2_blas!(view(R,:), reshape(A, (m, n)), w, 1, init)
478-
elseif dim == N
479-
n = size(A, N)
480-
m = div(length(A), n)
481-
_wsum2_blas!(view(R,:), reshape(A, (m, n)), w, 2, init)
482-
else # 1 < dim < N
483-
m = 1
484-
for i = 1:dim-1; m *= size(A, i); end
485-
n = size(A, dim)
486-
k = 1
487-
for i = dim+1:N; k *= size(A, i); end
488-
Av = reshape(A, (m, n, k))
489-
Rv = reshape(R, (m, k))
490-
for i = 1:k
491-
_wsum2_blas!(view(Rv,:,i), view(Av,:,:,i), w, 2, init)
492-
end
493-
end
494-
return R
495-
end
496-
497-
function _wsumN!(R::StridedArray{T}, A::DenseArray{T,N}, w::StridedVector{T}, dim::Int, init::Bool) where {T<:BlasReal,N}
498-
@assert N >= 3
499-
if dim <= 2
500-
m = size(A, 1)
501-
n = size(A, 2)
502-
npages = 1
503-
for i = 3:N
504-
npages *= size(A, i)
505-
end
506-
rlen = ifelse(dim == 1, n, m)
507-
Rv = reshape(R, (rlen, npages))
508-
for i = 1:npages
509-
_wsum2_blas!(view(Rv,:,i), view(A,:,:,i), w, dim, init)
510-
end
511-
else
512-
_wsum_general!(R, identity, A, w, dim, init)
513-
end
514-
return R
515-
end
516438

517439
## general Cartesian-based weighted sum across dimensions
518440

@@ -572,25 +494,12 @@ end
572494
end
573495
end
574496

575-
# N = 1
576-
_wsum!(R::StridedArray{T}, A::DenseArray{T,1}, w::StridedVector{T}, dim::Int, init::Bool) where {T<:BlasReal} =
577-
_wsum1!(R, A, w, init)
578-
579-
# N = 2
580-
_wsum!(R::StridedArray{T}, A::DenseArray{T,2}, w::StridedVector{T}, dim::Int, init::Bool) where {T<:BlasReal} =
581-
(_wsum2_blas!(view(R,:), A, w, dim, init); R)
582-
583-
# N >= 3
584-
_wsum!(R::StridedArray{T}, A::DenseArray{T,N}, w::StridedVector{T}, dim::Int, init::Bool) where {T<:BlasReal,N} =
585-
_wsumN!(R, A, w, dim, init)
586-
587497
_wsum!(R::AbstractArray, A::AbstractArray, w::AbstractVector, dim::Int, init::Bool) =
588498
_wsum_general!(R, identity, A, w, dim, init)
589499

590500
## wsum! and wsum
591501

592502
wsumtype(::Type{T}, ::Type{W}) where {T,W} = typeof(zero(T) * zero(W) + zero(T) * zero(W))
593-
wsumtype(::Type{T}, ::Type{T}) where {T<:BlasReal} = T
594503

595504
"""
596505
wsum!(R::AbstractArray, A::AbstractArray,
@@ -663,7 +572,6 @@ _mean!(R::AbstractArray, A::AbstractArray, w::AbstractWeights, dims::Int) =
663572
rmul!(Base.sum!(R, A, w, dims), inv(sum(w)))
664573

665574
wmeantype(::Type{T}, ::Type{W}) where {T,W} = typeof((zero(T)*zero(W) + zero(T)*zero(W)) / one(W))
666-
wmeantype(::Type{T}, ::Type{T}) where {T<:BlasReal} = T
667575

668576
"""
669577
mean(A::AbstractArray, w::AbstractWeights[, dims::Int])

0 commit comments

Comments
 (0)