Skip to content

Commit b670fee

Browse files
Use a faster implementation of AliasTables (#1848)
* switch to AliasTables.jl
* retune heuristic
* add test for #832
* add more tests
* move alias table import and tighten from `using` to `import`
* Back out multinomial heuristic adjustment at @adienes's request
* Update test/univariate/discrete/categorical.jl (style)

Co-authored-by: David Widmann <[email protected]>

---------

Co-authored-by: David Widmann <[email protected]>
1 parent f33af97 commit b670fee

File tree

5 files changed

+25
-21
lines changed

5 files changed

+25
-21
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["JuliaStats"]
44
version = "0.25.107"
55

66
[deps]
7+
AliasTables = "66dad0bd-aa9a-41b7-9441-69ab47430ed8"
78
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
89
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
910
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
@@ -30,6 +31,7 @@ DistributionsDensityInterfaceExt = "DensityInterface"
3031
DistributionsTestExt = "Test"
3132

3233
[compat]
34+
AliasTables = "1"
3335
Aqua = "0.8"
3436
Calculus = "0.5"
3537
ChainRulesCore = "1"

src/Distributions.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ import PDMats: dim, PDMat, invquad
2727
using SpecialFunctions
2828
using Base.MathConstants: eulergamma
2929

30+
import AliasTables
31+
3032
export
3133
# re-export Statistics
3234
mean, median, quantile, std, var, cov, cor,

src/samplers/aliastable.jl

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,7 @@
11
struct AliasTable <: Sampleable{Univariate,Discrete}
2-
accept::Vector{Float64}
3-
alias::Vector{Int}
2+
at::AliasTables.AliasTable{UInt64, Int}
3+
AliasTable(probs::AbstractVector{<:Real}) = new(AliasTables.AliasTable(probs))
44
end
5-
ncategories(s::AliasTable) = length(s.alias)
6-
7-
function AliasTable(probs::AbstractVector)
8-
n = length(probs)
9-
n > 0 || throw(ArgumentError("The input probability vector is empty."))
10-
accp = Vector{Float64}(undef, n)
11-
alias = Vector{Int}(undef, n)
12-
StatsBase.make_alias_table!(probs, 1.0, accp, alias)
13-
AliasTable(accp, alias)
14-
end
15-
16-
function rand(rng::AbstractRNG, s::AliasTable)
17-
i = rand(rng, 1:length(s.alias)) % Int
18-
# using `ifelse` improves performance here: github.com/JuliaStats/Distributions.jl/pull/1831/
19-
ifelse(rand(rng) < s.accept[i], i, s.alias[i])
20-
end
21-
5+
ncategories(s::AliasTable) = length(s.at)
6+
rand(rng::AbstractRNG, s::AliasTable) = rand(rng, s.at)
227
show(io::IO, s::AliasTable) = @printf(io, "AliasTable with %d entries", ncategories(s))

test/samplers.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,11 @@ import Distributions:
3131
@testset "p=$p" for p in Any[[1.0], [0.3, 0.7], [0.2, 0.3, 0.4, 0.1]]
3232
test_samples(S(p), Categorical(p), n_tsamples)
3333
test_samples(S(p), Categorical(p), n_tsamples, rng=rng)
34+
@test ncategories(S(p)) == length(p)
3435
end
3536
end
3637

38+
@test string(AliasTable(Float16[1,2,3])) == "AliasTable with 3 entries"
3739

3840
## Binomial samplers
3941

test/univariate/discrete/categorical.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,9 @@ end
9393
end
9494

9595
@testset "reproducibility across julia versions" begin
96-
d= Categorical([0.1, 0.2, 0.7])
96+
d = Categorical([0.1, 0.2, 0.7])
9797
rng = StableRNGs.StableRNG(600)
98-
@test rand(rng, d, 10) == [2, 1, 3, 3, 2, 3, 3, 3, 3, 3]
98+
@test rand(rng, d, 10) == [3, 1, 1, 2, 3, 2, 3, 3, 2, 3]
9999
end
100100

101101
@testset "comparisons" begin
@@ -124,4 +124,17 @@ end
124124
@test Categorical([0.5, 0.5]) ≈ Categorical([0.5f0, 0.5f0])
125125
end
126126

127+
@testset "issue #832" begin
128+
priorities = collect(Float64, 1:1000)
129+
priorities[1:50] .= 1e8
130+
131+
at = Distributions.AliasTable(priorities)
132+
iat = rand(at, 16)
133+
134+
# failure rate of a single sample is approximately sum(51:1000)/50e8 = 9.9845e-5 (the exact denominator, sum(priorities), also includes the small weights)
135+
# probability of at least 4 failures among 16 samples is 1-cdf(Binomial(16, 9.9845e-5), 3) = 1.8074430840897548e-13
136+
# this test should randomly fail with a probability of 1.8074430840897548e-13
137+
@test count(==(1e8), priorities[iat]) >= 13
138+
end
139+
127140
end

0 commit comments

Comments
 (0)