Skip to content

Commit e8ab265

Browse files
authored
Improve performance of weighted sum (#778)
The current code is calling the `AbstractArray` matrix multiplication fallback, which is slower than BLAS.
1 parent 5c011db commit e8ab265

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

src/weights.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,18 @@ Compute the weighted sum of an array `v` with weights `w`, optionally over the d
382382
"""
383383
wsum(v::AbstractArray, w::AbstractVector, dims::Colon=:) = transpose(w) * vec(v)
384384

385+
# Optimized methods (to ensure we use BLAS when possible)
386+
for W in (AnalyticWeights, FrequencyWeights, ProbabilityWeights, Weights)
387+
@eval begin
388+
wsum(v::AbstractArray, w::$W, dims::Colon) = transpose(w.values) * vec(v)
389+
end
390+
end
391+
392+
function wsum(A::AbstractArray, w::UnitWeights, dims::Colon)
393+
length(A) != length(w) && throw(DimensionMismatch("Inconsistent array dimension."))
394+
return sum(A)
395+
end
396+
385397
## wsum along dimension
386398
#
387399
# Brief explanation of the algorithm:
@@ -605,12 +617,6 @@ optionally over the dimension `dims`.
605617
Base.sum(A::AbstractArray, w::AbstractWeights{<:Real}; dims::Union{Colon,Int}=:) =
606618
wsum(A, w, dims)
607619

608-
function Base.sum(A::AbstractArray, w::UnitWeights; dims::Union{Colon,Int}=:)
609-
a = (dims === :) ? length(A) : size(A, dims)
610-
a != length(w) && throw(DimensionMismatch("Inconsistent array dimension."))
611-
return sum(A, dims=dims)
612-
end
613-
614620
##### Weighted means #####
615621

616622
function wmean(v::AbstractArray{<:Number}, w::AbstractVector)

test/weights.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,7 @@ end
476476
@testset "Sum, mean, quantiles and variance for unit weights" begin
477477
wt = uweights(Float64, 3)
478478

479-
@test sum([1.0, 2.0, 3.0], wt) 6.0
479+
@test sum([1.0, 2.0, 3.0], wt) wsum([1.0, 2.0, 3.0], wt) 6.0
480480
@test mean([1.0, 2.0, 3.0], wt) 2.0
481481

482482
@test sum(a, wt, dims=1) sum(a, dims=1)

0 commit comments

Comments
 (0)