Skip to content

Commit 983bf02

Browse files
Fix indexing UnitWeights with a Boolean vector (#603)
1 parent fbf9e66 commit 983bf02

File tree

2 files changed

+6
-0
lines changed

2 files changed

+6
-0
lines changed

src/weights.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,11 @@ end
289289
UnitWeights{T}(length(i))
290290
end
291291

292+
function Base.getindex(wv::UnitWeights{T}, i::AbstractArray{Bool}) where T
293+
length(wv) == length(i) || throw(DimensionMismatch())
294+
UnitWeights{T}(count(i))
295+
end
296+
292297
Base.getindex(wv::UnitWeights{T}, ::Colon) where {T} = UnitWeights{T}(wv.len)
293298

294299
"""

test/weights.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ end
112112
@test isequal(wv, uweights(3))
113113
@test wv != fweights(fill(1.0, 3))
114114
@test wv == uweights(3)
115+
@test wv[[true, false, false]] == uweights(Float64, 1)
115116
end
116117

117118
## wsum

0 commit comments

Comments
 (0)