Skip to content

Commit be809a8

Browse files
authored
Use a faster and safer implementation of alias_sample! (#927)
* Use a faster implementation for alias_sample! * add invalid weights tests * deprecate make_alias_table!
1 parent 60fb5cd commit be809a8

File tree

4 files changed

+76
-73
lines changed

4 files changed

+76
-73
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["JuliaStats"]
44
version = "0.34.3"
55

66
[deps]
7+
AliasTables = "66dad0bd-aa9a-41b7-9441-69ab47430ed8"
78
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
89
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
910
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -17,6 +18,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1718
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
1819

1920
[compat]
21+
AliasTables = "1"
2022
DataAPI = "1"
2123
DataStructures = "0.10, 0.11, 0.12, 0.13, 0.14, 0.17, 0.18"
2224
LinearAlgebra = "<0.0.1, 1"

src/deprecates.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,65 @@ end
4646
@deprecate stdm(x::AbstractArray{<:Real}, w::AbstractWeights, m::AbstractArray{<:Real}, dim::Int; corrected::Union{Bool, Nothing}=nothing) std(x, w, dim, mean=m, corrected=corrected) false
4747
@deprecate varm(x::AbstractArray{<:Real}, w::AbstractWeights, m::AbstractArray{<:Real}, dim::Int; corrected::Union{Bool, Nothing}=nothing) var(x, w, dim, mean=m, corrected=corrected) false
4848
@deprecate varm!(R::AbstractArray, x::AbstractArray{<:Real}, w::AbstractWeights, m::AbstractArray{<:Real}, dim::Int; corrected::Union{Bool, Nothing}=nothing) var!(R, x, w, dim, mean=m, corrected=corrected) false
49+
50+
### This was never part of the public API
51+
### Deprecated April 2024
52+
function make_alias_table!(w::AbstractVector, wsum,
53+
a::AbstractVector{Float64},
54+
alias::AbstractVector{Int})
55+
Base.depwarn("make_alias_table! is both internal and deprecated, use AliasTables.jl instead", :make_alias_table!)
56+
# Arguments:
57+
#
58+
# w [in]: input weights
59+
# wsum [in]: pre-computed sum(w)
60+
#
61+
# a [out]: acceptance probabilities
62+
# alias [out]: alias table
63+
#
64+
# Note: a and w can be the same array, then that array will be
65+
# overwritten inplace by acceptance probabilities
66+
#
67+
# Returns nothing
68+
#
69+
70+
n = length(w)
71+
length(a) == length(alias) == n ||
72+
throw(DimensionMismatch("Inconsistent array lengths."))
73+
74+
ac = n / wsum
75+
for i = 1:n
76+
@inbounds a[i] = w[i] * ac
77+
end
78+
79+
larges = Vector{Int}(undef, n)
80+
smalls = Vector{Int}(undef, n)
81+
kl = 0 # actual number of larges
82+
ks = 0 # actual number of smalls
83+
84+
for i = 1:n
85+
@inbounds ai = a[i]
86+
if ai > 1.0
87+
larges[kl+=1] = i # push to larges
88+
elseif ai < 1.0
89+
smalls[ks+=1] = i # push to smalls
90+
end
91+
end
92+
93+
while kl > 0 && ks > 0
94+
s = smalls[ks]; ks -= 1 # pop from smalls
95+
l = larges[kl]; kl -= 1 # pop from larges
96+
@inbounds alias[s] = l
97+
@inbounds al = a[l] = (a[l] - 1.0) + a[s]
98+
if al > 1.0
99+
larges[kl+=1] = l # push to larges
100+
else
101+
smalls[ks+=1] = l # push to smalls
102+
end
103+
end
104+
105+
# this loop should be redundant, except for rounding
106+
for i = 1:ks
107+
@inbounds a[smalls[i]] = 1.0
108+
end
109+
nothing
110+
end

src/sampling.jl

Lines changed: 9 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#
66
###########################################################
77

8+
using AliasTables
89
using Random: Sampler
910

1011
if VERSION < v"1.3.0-DEV.565"
@@ -635,65 +636,6 @@ end
635636
direct_sample!(a::AbstractArray, wv::AbstractWeights, x::AbstractArray) =
636637
direct_sample!(default_rng(), a, wv, x)
637638

638-
function make_alias_table!(w::AbstractVector, wsum,
639-
a::AbstractVector{Float64},
640-
alias::AbstractVector{Int})
641-
# Arguments:
642-
#
643-
# w [in]: input weights
644-
# wsum [in]: pre-computed sum(w)
645-
#
646-
# a [out]: acceptance probabilities
647-
# alias [out]: alias table
648-
#
649-
# Note: a and w can be the same array, then that array will be
650-
# overwritten inplace by acceptance probabilities
651-
#
652-
# Returns nothing
653-
#
654-
655-
n = length(w)
656-
length(a) == length(alias) == n ||
657-
throw(DimensionMismatch("Inconsistent array lengths."))
658-
659-
ac = n / wsum
660-
for i = 1:n
661-
@inbounds a[i] = w[i] * ac
662-
end
663-
664-
larges = Vector{Int}(undef, n)
665-
smalls = Vector{Int}(undef, n)
666-
kl = 0 # actual number of larges
667-
ks = 0 # actual number of smalls
668-
669-
for i = 1:n
670-
@inbounds ai = a[i]
671-
if ai > 1.0
672-
larges[kl+=1] = i # push to larges
673-
elseif ai < 1.0
674-
smalls[ks+=1] = i # push to smalls
675-
end
676-
end
677-
678-
while kl > 0 && ks > 0
679-
s = smalls[ks]; ks -= 1 # pop from smalls
680-
l = larges[kl]; kl -= 1 # pop from larges
681-
@inbounds alias[s] = l
682-
@inbounds al = a[l] = (a[l] - 1.0) + a[s]
683-
if al > 1.0
684-
larges[kl+=1] = l # push to larges
685-
else
686-
smalls[ks+=1] = l # push to smalls
687-
end
688-
end
689-
690-
# this loop should be redundant, except for rounding
691-
for i = 1:ks
692-
@inbounds a[smalls[i]] = 1.0
693-
end
694-
nothing
695-
end
696-
697639
"""
698640
alias_sample!([rng], a::AbstractArray, wv::AbstractWeights, x::AbstractArray)
699641
@@ -704,29 +646,23 @@ Build an alias table, and sample therefrom.
704646
Reference: Walker, A. J. "An Efficient Method for Generating Discrete Random Variables
705647
with General Distributions." *ACM Transactions on Mathematical Software* 3 (3): 253, 1977.
706648
707-
Noting `k=length(x)` and `n=length(a)`, this algorithm takes ``O(n \\log n)`` time
708-
for building the alias table, and then ``O(1)`` to draw each sample. It consumes ``2 k`` random numbers.
649+
Noting `k=length(x)` and `n=length(a)`, this algorithm takes ``O(n)`` time
650+
for building the alias table, and then ``O(1)`` to draw each sample. It consumes ``k`` random numbers.
709651
"""
710652
function alias_sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::AbstractArray)
711653
Base.mightalias(a, x) &&
712654
throw(ArgumentError("output array x must not share memory with input array a"))
713-
Base.mightalias(x, wv) &&
714-
throw(ArgumentError("output array x must not share memory with weights array wv"))
715-
1 == firstindex(a) == firstindex(wv) == firstindex(x) ||
655+
1 == firstindex(a) == firstindex(wv) ||
716656
throw(ArgumentError("non 1-based arrays are not supported"))
717-
n = length(a)
718-
length(wv) == n || throw(DimensionMismatch("Inconsistent lengths."))
657+
length(wv) == length(a) || throw(DimensionMismatch("Inconsistent lengths."))
719658

720659
# create alias table
721-
ap = Vector{Float64}(undef, n)
722-
alias = Vector{Int}(undef, n)
723-
make_alias_table!(wv, sum(wv), ap, alias)
660+
at = AliasTable(wv)
724661

725662
# sampling
726-
s = Sampler(rng, 1:n)
727-
for i = 1:length(x)
728-
j = rand(rng, s)
729-
x[i] = rand(rng) < ap[j] ? a[j] : a[alias[j]]
663+
for i in eachindex(x)
664+
j = rand(rng, at)
665+
x[i] = a[j]
730666
end
731667
return x
732668
end

test/wsampling.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ for wv in (
5555
check_wsample_wrep(a, (4, 7), wv, 5.0e-3; ordered=false)
5656
end
5757

58+
@test_throws ArgumentError alias_sample!(rand(10), weights(fill(0, 10)), rand(10))
59+
@test_throws ArgumentError alias_sample!(rand(100), weights(randn(100)), rand(10))
60+
5861
for rev in (true, false), T in (Int, Int16, Float64, Float16, BigInt, ComplexF64, Rational{Int})
5962
r = rev ? reverse(4:7) : (4:7)
6063
r = T===Int ? r : T.(r)

0 commit comments

Comments
 (0)