|
435 | 435 | # that keeps a local accumulator will be used when dim = 1. |
436 | 436 | # |
437 | 437 | # The internal function that implements this is _wsum_general! |
438 | | -# |
439 | | -# 3. _wsum! is specialized for following cases: |
440 | | -# (a) A is a vector: we invoke the vector version wsum above. |
441 | | -# The internal function that implements this is _wsum1! |
442 | | -# |
443 | | -# (b) A is a dense matrix with eltype <: BlasReal: we call gemv! |
444 | | -# The internal function that implements this is _wsum2_blas! |
445 | | -# |
446 | | -# (c) A is a contiguous array with eltype <: BlasReal: |
447 | | -# dim == 1: treat A like a matrix of size (d1, d2 x ... x dN) |
448 | | -# dim == N: treat A like a matrix of size (d1 x ... x d(N-1), dN) |
449 | | -# otherwise: decompose A into multiple pages, and apply _wsum2! |
450 | | -# for each |
451 | | -# |
452 | | -# (d) A is a general dense array with eltype <: BlasReal: |
453 | | -# dim <= 2: delegate to (a) and (b) |
454 | | -# otherwise, decompose A into multiple pages |
455 | | - |
456 | | -function _wsum1!(R::AbstractArray, A::AbstractVector, w::AbstractVector, init::Bool) |
457 | | - r = wsum(A, w) |
458 | | - if init |
459 | | - R[1] = r |
460 | | - else |
461 | | - R[1] += r |
462 | | - end |
463 | | - return R |
464 | | -end |
465 | | - |
466 | | -function _wsum2_blas!(R::StridedVector{T}, A::StridedMatrix{T}, w::StridedVector{T}, dim::Int, init::Bool) where T<:BlasReal |
467 | | - beta = ifelse(init, zero(T), one(T)) |
468 | | - trans = dim == 1 ? 'T' : 'N' |
469 | | - BLAS.gemv!(trans, one(T), A, w, beta, R) |
470 | | - return R |
471 | | -end |
472 | | - |
473 | | -function _wsumN!(R::StridedArray{T}, A::StridedArray{T,N}, w::StridedVector{T}, dim::Int, init::Bool) where {T<:BlasReal,N} |
474 | | - if dim == 1 |
475 | | - m = size(A, 1) |
476 | | - n = div(length(A), m) |
477 | | - _wsum2_blas!(view(R,:), reshape(A, (m, n)), w, 1, init) |
478 | | - elseif dim == N |
479 | | - n = size(A, N) |
480 | | - m = div(length(A), n) |
481 | | - _wsum2_blas!(view(R,:), reshape(A, (m, n)), w, 2, init) |
482 | | - else # 1 < dim < N |
483 | | - m = 1 |
484 | | - for i = 1:dim-1; m *= size(A, i); end |
485 | | - n = size(A, dim) |
486 | | - k = 1 |
487 | | - for i = dim+1:N; k *= size(A, i); end |
488 | | - Av = reshape(A, (m, n, k)) |
489 | | - Rv = reshape(R, (m, k)) |
490 | | - for i = 1:k |
491 | | - _wsum2_blas!(view(Rv,:,i), view(Av,:,:,i), w, 2, init) |
492 | | - end |
493 | | - end |
494 | | - return R |
495 | | -end |
496 | | - |
497 | | -function _wsumN!(R::StridedArray{T}, A::DenseArray{T,N}, w::StridedVector{T}, dim::Int, init::Bool) where {T<:BlasReal,N} |
498 | | - @assert N >= 3 |
499 | | - if dim <= 2 |
500 | | - m = size(A, 1) |
501 | | - n = size(A, 2) |
502 | | - npages = 1 |
503 | | - for i = 3:N |
504 | | - npages *= size(A, i) |
505 | | - end |
506 | | - rlen = ifelse(dim == 1, n, m) |
507 | | - Rv = reshape(R, (rlen, npages)) |
508 | | - for i = 1:npages |
509 | | - _wsum2_blas!(view(Rv,:,i), view(A,:,:,i), w, dim, init) |
510 | | - end |
511 | | - else |
512 | | - _wsum_general!(R, identity, A, w, dim, init) |
513 | | - end |
514 | | - return R |
515 | | -end |
516 | 438 |
|
517 | 439 | ## general Cartesian-based weighted sum across dimensions |
518 | 440 |
|
@@ -572,25 +494,12 @@ end |
572 | 494 | end |
573 | 495 | end |
574 | 496 |
|
575 | | -# N = 1 |
576 | | -_wsum!(R::StridedArray{T}, A::DenseArray{T,1}, w::StridedVector{T}, dim::Int, init::Bool) where {T<:BlasReal} = |
577 | | - _wsum1!(R, A, w, init) |
578 | | - |
579 | | -# N = 2 |
580 | | -_wsum!(R::StridedArray{T}, A::DenseArray{T,2}, w::StridedVector{T}, dim::Int, init::Bool) where {T<:BlasReal} = |
581 | | - (_wsum2_blas!(view(R,:), A, w, dim, init); R) |
582 | | - |
583 | | -# N >= 3 |
584 | | -_wsum!(R::StridedArray{T}, A::DenseArray{T,N}, w::StridedVector{T}, dim::Int, init::Bool) where {T<:BlasReal,N} = |
585 | | - _wsumN!(R, A, w, dim, init) |
586 | | - |
587 | 497 | _wsum!(R::AbstractArray, A::AbstractArray, w::AbstractVector, dim::Int, init::Bool) = |
588 | 498 | _wsum_general!(R, identity, A, w, dim, init) |
589 | 499 |
|
590 | 500 | ## wsum! and wsum |
591 | 501 |
|
592 | 502 | wsumtype(::Type{T}, ::Type{W}) where {T,W} = typeof(zero(T) * zero(W) + zero(T) * zero(W)) |
593 | | -wsumtype(::Type{T}, ::Type{T}) where {T<:BlasReal} = T |
594 | 503 |
|
595 | 504 | """ |
596 | 505 | wsum!(R::AbstractArray, A::AbstractArray, |
@@ -663,7 +572,6 @@ _mean!(R::AbstractArray, A::AbstractArray, w::AbstractWeights, dims::Int) = |
663 | 572 | rmul!(Base.sum!(R, A, w, dims), inv(sum(w))) |
664 | 573 |
|
665 | 574 | wmeantype(::Type{T}, ::Type{W}) where {T,W} = typeof((zero(T)*zero(W) + zero(T)*zero(W)) / one(W)) |
666 | | -wmeantype(::Type{T}, ::Type{T}) where {T<:BlasReal} = T |
667 | 575 |
|
668 | 576 | """ |
669 | 577 | mean(A::AbstractArray, w::AbstractWeights[, dims::Int]) |
|
0 commit comments