Skip to content

Commit 17154a2

Browse files
committed
Fix rand(::Beta) inconsistencies
1 parent 0ea5502 commit 17154a2

File tree

3 files changed

+45
-22
lines changed

3 files changed

+45
-22
lines changed

src/univariate/continuous/beta.jl

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -138,15 +138,16 @@ struct BetaSampler{T<:Real, S1 <: Sampleable{Univariate,Continuous},
138138
s2::S2
139139
end
140140

141-
function sampler(d::Beta{T}) where T
142-
(α, β) = params(d)
143-
if (α 1.0) && (β 1.0)
141+
function sampler(d::Beta)
142+
α, β = params(d)
143+
if α 1 && β 1
144144
return BetaSampler(false, inv(α), inv(β),
145-
sampler(Uniform()), sampler(Uniform()))
145+
sampler(Uniform(zero(α), oneunit(α))),
146+
sampler(Uniform(zero(β), oneunit(β))))
146147
else
147148
return BetaSampler(true, inv(α), inv(β),
148-
sampler(Gamma(α, one(T))),
149-
sampler(Gamma(β, one(T))))
149+
sampler(Gamma(α, oneunit))),
150+
sampler(Gamma(β, oneunit))))
150151
end
151152
end
152153

@@ -160,11 +161,11 @@ function rand(rng::AbstractRNG, s::BetaSampler)
160161
= s.
161162
= s.
162163
while true
163-
u = rand(rng) # the Uniform sampler just calls rand()
164-
v = rand(rng)
164+
u = rand(rng, s.s1) # the Uniform sampler just calls rand()
165+
v = rand(rng, s.s2)
165166
x = u^
166167
y = v^
167-
if x + y one(x)
168+
if x + y 1
168169
if (x + y > 0)
169170
return x / (x + y)
170171
else
@@ -180,16 +181,20 @@ function rand(rng::AbstractRNG, s::BetaSampler)
180181
end
181182
end
182183

183-
function rand(rng::AbstractRNG, d::Beta{T}) where T
184-
(α, β) = params(d)
185-
if 1.0) && 1.0)
184+
function rand(rng::AbstractRNG, d::Beta)
185+
α, β = params(d)
186+
if α 1 && β 1
187+
= inv(α)
188+
= inv(β)
189+
Tu = typeof(float(iα))
190+
Tv = typeof(float(iβ))
186191
while true
187-
u = rand(rng)
188-
v = rand(rng)
189-
x = u^inv(α)
190-
y = v^inv(β)
191-
if x + y one(x)
192-
if (x + y > 0)
192+
u = rand(rng, Tu)
193+
v = rand(rng, Tv)
194+
x = u^
195+
y = v^
196+
if x + y 1
197+
if x + y > 0
193198
return x / (x + y)
194199
else
195200
logX = log(u) / α
@@ -202,8 +207,8 @@ function rand(rng::AbstractRNG, d::Beta{T}) where T
202207
end
203208
end
204209
else
205-
g1 = rand(rng, Gamma(α, one(T)))
206-
g2 = rand(rng, Gamma(β, one(T)))
210+
g1 = rand(rng, Gamma(α, oneunit)))
211+
g2 = rand(rng, Gamma(β, oneunit)))
207212
return g1 / (g1 + g2)
208213
end
209214
end

test/runtests.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ const tests = [
7272
"univariate/discrete/poisson",
7373
"univariate/discrete/soliton",
7474
"univariate/continuous/skewnormal",
75+
"univariate/continuous/beta",
7576
"univariate/continuous/chi",
7677
"univariate/continuous/chisq",
7778
"univariate/continuous/erlang",
@@ -129,8 +130,6 @@ const tests = [
129130
# "samplers/vonmisesfisher",
130131
# "show",
131132
# "truncated/loguniform",
132-
# "univariate/continuous/beta",
133-
# "univariate/continuous/beta",
134133
# "univariate/continuous/betaprime",
135134
# "univariate/continuous/biweight",
136135
# "univariate/continuous/cosine",

test/univariate/continuous/beta.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using Distributions
2+
using Test
3+
4+
@testset "beta.jl" begin
5+
# issue #1907
6+
@testset "rand consistency" begin
7+
for T in (Float32, Float64)
8+
@test @inferred(rand(Beta(T(1), T(1)))) isa T
9+
@test @inferred(rand(Beta(T(4//5), T(4//5)))) isa T
10+
@test @inferred(rand(Beta(T(1), T(2)))) isa T
11+
@test @inferred(rand(Beta(T(2), T(1)))) isa T
12+
13+
@test @inferred(eltype(rand(Beta(T(1), T(1)), 2))) === T
14+
@test @inferred(eltype(rand(Beta(T(4//5), T(4//5)), 2))) === T
15+
@test @inferred(eltype(rand(Beta(T(1), T(2)), 2))) === T
16+
@test @inferred(eltype(rand(Beta(T(2), T(1)), 2))) === T
17+
end
18+
end
19+
end

0 commit comments

Comments
 (0)