@@ -440,16 +440,16 @@ end
440440# (a) A is a vector: we invoke the vector version wsum above.
441441# The internal function that implements this is _wsum1!
442442#
443- # (b) A is a dense matrix with eltype <: BlasReal : we call gemv!
443+ # (b) A is a dense matrix with `Float32` or `Float64` `eltype` : we call gemv!
444444# The internal function that implements this is _wsum2_blas!
445445#
446- # (c) A is a contiguous array with eltype <: BlasReal :
446+ # (c) A is a contiguous array with `Float32` or `Float64` `eltype` :
447447# dim == 1: treat A like a matrix of size (d1, d2 x ... x dN)
448448# dim == N: treat A like a matrix of size (d1 x ... x d(N-1), dN)
449449# otherwise: decompose A into multiple pages, and apply _wsum2!
450450# for each
451451#
452- # (d) A is a general dense array with eltype <: BlasReal :
452+ # (d) A is a general dense array with `Float32` or `Float64` `eltype` :
453453# dim <= 2: delegate to (a) and (b)
454454# otherwise, decompose A into multiple pages
455455
@@ -463,56 +463,6 @@ function _wsum1!(R::AbstractArray, A::AbstractVector, w::AbstractVector, init::B
463463 return R
464464end
465465
466- function _wsum2_blas! (R:: StridedVector{T} , A:: StridedMatrix{T} , w:: StridedVector{T} , dim:: Int , init:: Bool ) where T<: BlasReal
467- trans = dim == 1 ? ' T' : ' N'
468- BLAS. gemv! (trans, true , A, w, ! init, R)
469- return R
470- end
471-
472- function _wsumN! (R:: StridedArray{T} , A:: StridedArray{T,N} , w:: StridedVector{T} , dim:: Int , init:: Bool ) where {T<: BlasReal ,N}
473- if dim == 1
474- m = size (A, 1 )
475- n = div (length (A), m)
476- _wsum2_blas! (view (R,:), reshape (A, (m, n)), w, 1 , init)
477- elseif dim == N
478- n = size (A, N)
479- m = div (length (A), n)
480- _wsum2_blas! (view (R,:), reshape (A, (m, n)), w, 2 , init)
481- else # 1 < dim < N
482- m = 1
483- for i = 1 : dim- 1 ; m *= size (A, i); end
484- n = size (A, dim)
485- k = 1
486- for i = dim+ 1 : N; k *= size (A, i); end
487- Av = reshape (A, (m, n, k))
488- Rv = reshape (R, (m, k))
489- for i = 1 : k
490- _wsum2_blas! (view (Rv,:,i), view (Av,:,:,i), w, 2 , init)
491- end
492- end
493- return R
494- end
495-
496- function _wsumN! (R:: StridedArray{T} , A:: DenseArray{T,N} , w:: StridedVector{T} , dim:: Int , init:: Bool ) where {T<: BlasReal ,N}
497- @assert N >= 3
498- if dim <= 2
499- m = size (A, 1 )
500- n = size (A, 2 )
501- npages = 1
502- for i = 3 : N
503- npages *= size (A, i)
504- end
505- rlen = ifelse (dim == 1 , n, m)
506- Rv = reshape (R, (rlen, npages))
507- for i = 1 : npages
508- _wsum2_blas! (view (Rv,:,i), view (A,:,:,i), w, dim, init)
509- end
510- else
511- _wsum_general! (R, identity, A, w, dim, init)
512- end
513- return R
514- end
515-
516466# # general Cartesian-based weighted sum across dimensions
517467
518468@generated function _wsum_general! (R:: AbstractArray{RT} , f:: supertype (typeof (abs)),
@@ -571,25 +521,10 @@ end
571521 end
572522end
573523
574- # N = 1
575- _wsum! (R:: StridedArray{T} , A:: DenseArray{T,1} , w:: StridedVector{T} , dim:: Int , init:: Bool ) where {T<: BlasReal } =
576- _wsum1! (R, A, w, init)
577-
578- # N = 2
579- _wsum! (R:: StridedArray{T} , A:: DenseArray{T,2} , w:: StridedVector{T} , dim:: Int , init:: Bool ) where {T<: BlasReal } =
580- (_wsum2_blas! (view (R,:), A, w, dim, init); R)
581-
582- # N >= 3
583- _wsum! (R:: StridedArray{T} , A:: DenseArray{T,N} , w:: StridedVector{T} , dim:: Int , init:: Bool ) where {T<: BlasReal ,N} =
584- _wsumN! (R, A, w, dim, init)
585-
586524_wsum! (R:: AbstractArray , A:: AbstractArray , w:: AbstractVector , dim:: Int , init:: Bool ) =
587525 _wsum_general! (R, identity, A, w, dim, init)
588526
589- # # wsum! and wsum
590-
591527wsumtype (:: Type{T} , :: Type{W} ) where {T,W} = typeof (zero (T) * zero (W) + zero (T) * zero (W))
592- wsumtype (:: Type{T} , :: Type{T} ) where {T<: BlasReal } = T
593528
594529"""
595530 wsum!(R::AbstractArray, A::AbstractArray,
@@ -662,7 +597,6 @@ _mean!(R::AbstractArray, A::AbstractArray, w::AbstractWeights, dims::Int) =
662597 rmul! (Base. sum! (R, A, w, dims), inv (sum (w)))
663598
664599wmeantype (:: Type{T} , :: Type{W} ) where {T,W} = typeof ((zero (T)* zero (W) + zero (T)* zero (W)) / one (W))
665- wmeantype (:: Type{T} , :: Type{T} ) where {T<: BlasReal } = T
666600
667601"""
668602 mean(A::AbstractArray, w::AbstractWeights[, dims::Int])
@@ -692,6 +626,76 @@ function mean(A::AbstractArray, w::UnitWeights; dims::Union{Colon,Int}=:)
692626 return mean (A, dims= dims)
693627end
694628
629+ # #### BLAS optimizations for Float32 and Float64 #####
630+
631+ for T in (:Float32 , :Float64 )
632+ @eval begin
633+ function _wsum! (R:: StridedArray{$T} , A:: DenseArray{$T,N} , w:: StridedVector{$T} , dim:: Int , init:: Bool ) where {N}
634+ if N === 1
635+ _wsum1! (R, A, w, init)
636+ elseif N === 2
637+ _wsum2_blas! (view (R,:), A, w, dim, init)
638+ else # N >= 3
639+ _wsumN! (R, A, w, dim, init)
640+ end
641+ return R
642+ end
643+
644+ wsumtype (:: Type{$T} , :: Type{$T} ) = $ T
645+ wmeantype (:: Type{$T} , :: Type{$T} ) = $ T
646+
647+ function _wsum2_blas! (R:: StridedVector{$T} , A:: StridedMatrix{$T} , w:: StridedVector{$T} , dim:: Int , init:: Bool )
648+ trans = dim == 1 ? ' T' : ' N'
649+ BLAS. gemv! (trans, true , A, w, ! init, R)
650+ return R
651+ end
652+
653+ function _wsumN! (R:: StridedArray{$T} , A:: StridedArray{$T,N} , w:: StridedVector{$T} , dim:: Int , init:: Bool ) where {N}
654+ if dim == 1
655+ m = size (A, 1 )
656+ n = div (length (A), m)
657+ _wsum2_blas! (view (R,:), reshape (A, (m, n)), w, 1 , init)
658+ elseif dim == N
659+ n = size (A, N)
660+ m = div (length (A), n)
661+ _wsum2_blas! (view (R,:), reshape (A, (m, n)), w, 2 , init)
662+ else # 1 < dim < N
663+ m = 1
664+ for i = 1 : dim- 1 ; m *= size (A, i); end
665+ n = size (A, dim)
666+ k = 1
667+ for i = dim+ 1 : N; k *= size (A, i); end
668+ Av = reshape (A, (m, n, k))
669+ Rv = reshape (R, (m, k))
670+ for i = 1 : k
671+ _wsum2_blas! (view (Rv,:,i), view (Av,:,:,i), w, 2 , init)
672+ end
673+ end
674+ return R
675+ end
676+
677+ function _wsumN! (R:: StridedArray{$T} , A:: DenseArray{$T,N} , w:: StridedVector{$T} , dim:: Int , init:: Bool ) where {N}
678+ @assert N >= 3
679+ if dim <= 2
680+ m = size (A, 1 )
681+ n = size (A, 2 )
682+ npages = 1
683+ for i = 3 : N
684+ npages *= size (A, i)
685+ end
686+ rlen = ifelse (dim == 1 , n, m)
687+ Rv = reshape (R, (rlen, npages))
688+ for i = 1 : npages
689+ _wsum2_blas! (view (Rv,:,i), view (A,:,:,i), w, dim, init)
690+ end
691+ else
692+ _wsum_general! (R, identity, A, w, dim, init)
693+ end
694+ return R
695+ end
696+ end
697+ end
698+
695699# #### Weighted quantile #####
696700
697701"""
0 commit comments