Skip to content

Commit 3b25647

Browse files
findmywaydevmotion
andauthored
support Float32 for alias sampling (#499)
* support Float32 for alias sampling * update tests * rm unnecessary sig * resolve comments * Update src/sampling.jl Co-authored-by: David Widmann <[email protected]> Co-authored-by: David Widmann <[email protected]>
1 parent c4432ab commit 3b25647

File tree

2 files changed

+18
-11
lines changed

2 files changed

+18
-11
lines changed

src/sampling.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,7 @@ end
584584
direct_sample!(a::AbstractArray, wv::AbstractWeights, x::AbstractArray) =
585585
direct_sample!(Random.GLOBAL_RNG, a, wv, x)
586586

587-
function make_alias_table!(w::AbstractVector{Float64}, wsum::Float64,
587+
function make_alias_table!(w::AbstractVector, wsum,
588588
a::AbstractVector{Float64},
589589
alias::AbstractVector{Int})
590590
# Arguments:

test/wsampling.jl

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,23 @@ import StatsBase: direct_sample!, alias_sample!
3737
n = 10^6
3838
wv = weights([0.2, 0.8, 0.4, 0.6])
3939

40-
a = direct_sample!(4:7, wv, zeros(Int, n, 3))
41-
check_wsample_wrep(a, (4, 7), wv, 5.0e-3; ordered=false)
42-
test_rng_use(direct_sample!, 4:7, wv, zeros(Int, 100))
43-
44-
a = alias_sample!(4:7, wv, zeros(Int, n, 3))
45-
check_wsample_wrep(a, (4, 7), wv, 5.0e-3; ordered=false)
46-
47-
a = sample(4:7, wv, n; ordered=false)
48-
check_wsample_wrep(a, (4, 7), wv, 5.0e-3; ordered=false)
40+
for wv in [
41+
weights([0.2, 0.8, 0.4, 0.6]),
42+
weights([2, 8, 4, 6]),
43+
weights(Float32[0.2, 0.8, 0.4, 0.6]),
44+
Weights(Float32[0.2, 0.8, 0.4, 0.6], 2),
45+
Weights([2, 8, 4, 6], 20.0),
46+
]
47+
a = direct_sample!(4:7, wv, zeros(Int, n, 3))
48+
check_wsample_wrep(a, (4, 7), wv, 5.0e-3; ordered=false)
49+
test_rng_use(direct_sample!, 4:7, wv, zeros(Int, 100))
50+
51+
a = alias_sample!(4:7, wv, zeros(Int, n, 3))
52+
check_wsample_wrep(a, (4, 7), wv, 5.0e-3; ordered=false)
53+
54+
a = sample(4:7, wv, n; ordered=false)
55+
check_wsample_wrep(a, (4, 7), wv, 5.0e-3; ordered=false)
56+
end
4957

5058
for rev in (true, false), T in (Int, Int16, Float64, Float16, BigInt, ComplexF64, Rational{Int})
5159
r = rev ? reverse(4:7) : (4:7)
@@ -56,7 +64,6 @@ for rev in (true, false), T in (Int, Int16, Float64, Float16, BigInt, ComplexF64
5664
check_wsample_wrep(aa, (4, 7), wv, -1; ordered=true, rev=rev)
5765
end
5866

59-
6067
#### weighted sampling without replacement
6168

6269
function check_wsample_norep(a::AbstractArray, vrgn, wv::AbstractWeights, ptol::Real;

0 commit comments

Comments
 (0)