Skip to content

Commit 9239a6d

Browse files
committed
Only define BLAS optimizations for Float32 and Float64, not their union
1 parent 006584b commit 9239a6d

File tree

1 file changed

+73
-69
lines changed

1 file changed

+73
-69
lines changed

src/weights.jl

Lines changed: 73 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -440,16 +440,16 @@ end
440440
# (a) A is a vector: we invoke the vector version wsum above.
441441
# The internal function that implements this is _wsum1!
442442
#
443-
# (b) A is a dense matrix with eltype <: BlasReal: we call gemv!
443+
# (b) A is a dense matrix with `Float32` or `Float64` `eltype`: we call gemv!
444444
# The internal function that implements this is _wsum2_blas!
445445
#
446-
# (c) A is a contiguous array with eltype <: BlasReal:
446+
# (c) A is a contiguous array with `Float32` or `Float64` `eltype`:
447447
# dim == 1: treat A like a matrix of size (d1, d2 x ... x dN)
448448
# dim == N: treat A like a matrix of size (d1 x ... x d(N-1), dN)
449449
# otherwise: decompose A into multiple pages, and apply _wsum2_blas!
450450
# for each
451451
#
452-
# (d) A is a general dense array with eltype <: BlasReal:
452+
# (d) A is a general dense array with `Float32` or `Float64` `eltype`:
453453
# N <= 2: delegate to (a) and (b)
454454
# N >= 3: for dim <= 2, decompose A into matrix pages and apply _wsum2_blas! to each; otherwise fall back to _wsum_general!
455455

@@ -463,56 +463,6 @@ function _wsum1!(R::AbstractArray, A::AbstractVector, w::AbstractVector, init::B
463463
return R
464464
end
465465

466-
function _wsum2_blas!(R::StridedVector{T}, A::StridedMatrix{T}, w::StridedVector{T}, dim::Int, init::Bool) where T<:BlasReal
467-
trans = dim == 1 ? 'T' : 'N'
468-
BLAS.gemv!(trans, true, A, w, !init, R)
469-
return R
470-
end
471-
472-
function _wsumN!(R::StridedArray{T}, A::StridedArray{T,N}, w::StridedVector{T}, dim::Int, init::Bool) where {T<:BlasReal,N}
473-
if dim == 1
474-
m = size(A, 1)
475-
n = div(length(A), m)
476-
_wsum2_blas!(view(R,:), reshape(A, (m, n)), w, 1, init)
477-
elseif dim == N
478-
n = size(A, N)
479-
m = div(length(A), n)
480-
_wsum2_blas!(view(R,:), reshape(A, (m, n)), w, 2, init)
481-
else # 1 < dim < N
482-
m = 1
483-
for i = 1:dim-1; m *= size(A, i); end
484-
n = size(A, dim)
485-
k = 1
486-
for i = dim+1:N; k *= size(A, i); end
487-
Av = reshape(A, (m, n, k))
488-
Rv = reshape(R, (m, k))
489-
for i = 1:k
490-
_wsum2_blas!(view(Rv,:,i), view(Av,:,:,i), w, 2, init)
491-
end
492-
end
493-
return R
494-
end
495-
496-
function _wsumN!(R::StridedArray{T}, A::DenseArray{T,N}, w::StridedVector{T}, dim::Int, init::Bool) where {T<:BlasReal,N}
497-
@assert N >= 3
498-
if dim <= 2
499-
m = size(A, 1)
500-
n = size(A, 2)
501-
npages = 1
502-
for i = 3:N
503-
npages *= size(A, i)
504-
end
505-
rlen = ifelse(dim == 1, n, m)
506-
Rv = reshape(R, (rlen, npages))
507-
for i = 1:npages
508-
_wsum2_blas!(view(Rv,:,i), view(A,:,:,i), w, dim, init)
509-
end
510-
else
511-
_wsum_general!(R, identity, A, w, dim, init)
512-
end
513-
return R
514-
end
515-
516466
## general Cartesian-based weighted sum across dimensions
517467

518468
@generated function _wsum_general!(R::AbstractArray{RT}, f::supertype(typeof(abs)),
@@ -571,25 +521,10 @@ end
571521
end
572522
end
573523

574-
# N = 1
575-
_wsum!(R::StridedArray{T}, A::DenseArray{T,1}, w::StridedVector{T}, dim::Int, init::Bool) where {T<:BlasReal} =
576-
_wsum1!(R, A, w, init)
577-
578-
# N = 2
579-
_wsum!(R::StridedArray{T}, A::DenseArray{T,2}, w::StridedVector{T}, dim::Int, init::Bool) where {T<:BlasReal} =
580-
(_wsum2_blas!(view(R,:), A, w, dim, init); R)
581-
582-
# N >= 3
583-
_wsum!(R::StridedArray{T}, A::DenseArray{T,N}, w::StridedVector{T}, dim::Int, init::Bool) where {T<:BlasReal,N} =
584-
_wsumN!(R, A, w, dim, init)
585-
586524
# Generic fallback: forwards to `_wsum_general!`, passing `identity` as its
# function argument, and returns whatever that call returns.
function _wsum!(R::AbstractArray, A::AbstractArray, w::AbstractVector, dim::Int, init::Bool)
    return _wsum_general!(R, identity, A, w, dim, init)
end
588526

589-
## wsum! and wsum
590-
591527
# Result type of a weighted sum for values of type `T` and weights of type `W`:
# the type obtained by adding two value-times-weight products.
function wsumtype(::Type{T}, ::Type{W}) where {T,W}
    term = zero(T) * zero(W)
    return typeof(term + term)
end
592-
wsumtype(::Type{T}, ::Type{T}) where {T<:BlasReal} = T
593528

594529
"""
595530
wsum!(R::AbstractArray, A::AbstractArray,
@@ -662,7 +597,6 @@ _mean!(R::AbstractArray, A::AbstractArray, w::AbstractWeights, dims::Int) =
662597
rmul!(Base.sum!(R, A, w, dims), inv(sum(w)))
663598

664599
# Result type of a weighted mean for values of type `T` and weights of type `W`:
# the type of a sum of two value-times-weight products divided by `one(W)`.
function wmeantype(::Type{T}, ::Type{W}) where {T,W}
    term = zero(T) * zero(W)
    return typeof((term + term) / one(W))
end
665-
wmeantype(::Type{T}, ::Type{T}) where {T<:BlasReal} = T
666600

667601
"""
668602
mean(A::AbstractArray, w::AbstractWeights[, dims::Int])
@@ -692,6 +626,76 @@ function mean(A::AbstractArray, w::UnitWeights; dims::Union{Colon,Int}=:)
692626
return mean(A, dims=dims)
693627
end
694628

629+
##### BLAS optimizations for Float32 and Float64 #####
630+
631+
# BLAS-backed fast paths. The methods are generated once per concrete element
# type with `@eval`, so each of them applies only when the output, the data and
# the weights all share exactly that element type (`Float32` or `Float64`).
for T in (:Float32, :Float64)
    @eval begin
        # Entry point for dense arrays: pick the implementation from the rank `N`.
        # Returns `R`.
        function _wsum!(R::StridedArray{$T}, A::DenseArray{$T,N}, w::StridedVector{$T},
                        dim::Int, init::Bool) where {N}
            if N === 1
                _wsum1!(R, A, w, init)
            elseif N === 2
                # `view(R,:)` presents `R` as a vector, which is what gemv! writes to.
                _wsum2_blas!(view(R,:), A, w, dim, init)
            else # N >= 3
                _wsumN!(R, A, w, dim, init)
            end
            return R
        end

        # When values and weights share the element type, sums and means keep it.
        wsumtype(::Type{$T}, ::Type{$T}) = $T
        wmeantype(::Type{$T}, ::Type{$T}) = $T

        # Weighted sum of the matrix `A` along `dim`, computed with `BLAS.gemv!`.
        # `dim == 1` uses the transposed product (A' * w), any other `dim` uses
        # A * w. The scaling of the existing contents of `R` is `!init`: with
        # `init == true` the result overwrites `R`, otherwise it is added to `R`.
        function _wsum2_blas!(R::StridedVector{$T}, A::StridedMatrix{$T},
                              w::StridedVector{$T}, dim::Int, init::Bool)
            trans = dim == 1 ? 'T' : 'N'
            BLAS.gemv!(trans, true, A, w, !init, R)
            return R
        end

        # Strided arrays: reshape `A` so that the reduction becomes one or more
        # matrix-vector products.
        #   dim == 1: A is seen as a matrix of size (d1, d2 x ... x dN)
        #   dim == N: A is seen as a matrix of size (d1 x ... x d(N-1), dN)
        #   otherwise: A is seen as (m, n, k) and each of the k slices is reduced
        #              along its second dimension.
        # Note: a `DenseArray` argument selects the more specific method below.
        function _wsumN!(R::StridedArray{$T}, A::StridedArray{$T,N}, w::StridedVector{$T},
                         dim::Int, init::Bool) where {N}
            if dim == 1
                m = size(A, 1)
                n = div(length(A), m)
                _wsum2_blas!(view(R,:), reshape(A, (m, n)), w, 1, init)
            elseif dim == N
                n = size(A, N)
                m = div(length(A), n)
                _wsum2_blas!(view(R,:), reshape(A, (m, n)), w, 2, init)
            else # 1 < dim < N
                m = 1
                for i in 1:dim-1
                    m *= size(A, i)
                end
                n = size(A, dim)
                k = 1
                for i in dim+1:N
                    k *= size(A, i)
                end
                Av = reshape(A, (m, n, k))
                Rv = reshape(R, (m, k))
                for i in 1:k
                    _wsum2_blas!(view(Rv,:,i), view(Av,:,:,i), w, 2, init)
                end
            end
            return R
        end

        # Dense arrays with N >= 3.
        #   dim <= 2: A is split into `npages` matrices of size (m, n), each of
        #             which is reduced along `dim` with `_wsum2_blas!`.
        #   dim >= 3: fall back to `_wsum_general!`.
        function _wsumN!(R::StridedArray{$T}, A::DenseArray{$T,N}, w::StridedVector{$T},
                         dim::Int, init::Bool) where {N}
            @assert N >= 3
            if dim <= 2
                m = size(A, 1)
                n = size(A, 2)
                npages = 1
                for i in 3:N
                    npages *= size(A, i)
                end
                rlen = ifelse(dim == 1, n, m)
                # Collapse all trailing dimensions into a single page index first.
                # Indexing the N-dimensional `A` directly with three indices is
                # only in bounds when dimensions 4:N have length 1, so for N >= 4
                # `view(A,:,:,i)` would throw a BoundsError.
                Av = reshape(A, (m, n, npages))
                Rv = reshape(R, (rlen, npages))
                for i in 1:npages
                    _wsum2_blas!(view(Rv,:,i), view(Av,:,:,i), w, dim, init)
                end
            else
                _wsum_general!(R, identity, A, w, dim, init)
            end
            return R
        end
    end
end
698+
695699
##### Weighted quantile #####
696700

697701
"""

0 commit comments

Comments
 (0)