Skip to content

Commit 9916e14

Browse files
authored
fix weighted computations for non-real arrays (#737)
1 parent 1678fd1 commit 9916e14

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

src/weights.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -380,9 +380,7 @@ Base.:(==)(x::AbstractWeights, y::AbstractWeights) = false
380380
381381
Compute the weighted sum of an array `v` with weights `w`, optionally over the dimension `dim`.
382382
"""
383-
wsum(v::AbstractVector, w::AbstractVector) = dot(v, w)
384-
wsum(v::AbstractArray, w::AbstractVector) = dot(vec(v), w)
385-
wsum(v::AbstractArray, w::AbstractVector, dims::Colon) = wsum(v, w)
383+
wsum(v::AbstractArray, w::AbstractVector, dims::Colon=:) = transpose(w) * vec(v)
386384

387385
## wsum along dimension
388386
#

test/weights.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,8 @@ a = reshape(1.0:27.0, 3, 3, 3)
239239
@testset "Sum $f" for f in weight_funcs
240240
@test sum([1.0, 2.0, 3.0], f([1.0, 0.5, 0.5])) 3.5
241241
@test sum(1:3, f([1.0, 1.0, 0.5])) 4.5
242+
@test sum([1 + 2im, 2 + 3im], f([1.0, 0.5])) 2 + 3.5im
243+
@test sum([[1, 2], [3, 4]], f([2, 3])) == [11, 16]
242244

243245
for wt in ([1.0, 1.0, 1.0], [1.0, 0.2, 0.0], [0.2, 0.0, 1.0])
244246
@test sum(a, f(wt), dims=1) sum(a.*reshape(wt, length(wt), 1, 1), dims=1)
@@ -250,6 +252,7 @@ end
250252
@testset "Mean $f" for f in weight_funcs
251253
@test mean([1:3;], f([1.0, 1.0, 0.5])) 1.8
252254
@test mean(1:3, f([1.0, 1.0, 0.5])) 1.8
255+
@test mean([1 + 2im, 4 + 5im], f([1.0, 0.5])) 2 + 3im
253256

254257
for wt in ([1.0, 1.0, 1.0], [1.0, 0.2, 0.0], [0.2, 0.0, 1.0])
255258
@test mean(a, f(wt), dims=1) sum(a.*reshape(wt, length(wt), 1, 1), dims=1)/sum(wt)

0 commit comments

Comments
 (0)