Skip to content

Commit 004a468

Browse files
committed
fix implementation of rand to close #65
1 parent 66f23a3 commit 004a468

File tree

2 files changed

+31
-4
lines changed

2 files changed

+31
-4
lines changed

src/methods.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,8 @@ end
215215
"""
216216
_cumulative(d::UnivariateFinite)
217217
218+
**Private method.**
219+
218220
Return the cumulative probability vector `C` for the distribution `d`,
219221
using only classes in the support of `d`, ordered according to the
220222
categorical elements used at instantiation of `d`. Used only to
@@ -238,13 +240,15 @@ end
238240
"""
239241
_rand(rng, p_cumulative, R)
240242
243+
**Private method.**
244+
241245
Randomly sample the distribution with discrete support `R(1):R(n)`
242246
which has cumulative probability vector `p_cumulative` (see
243247
[`_cummulative`](@ref)).
244248
245249
"""
246250
function _rand(rng, p_cumulative, R)
247-
real_sample = rand(rng)*p_cumulative[end]
251+
real_sample = Base.rand(rng)*p_cumulative[end]
248252
K = R(length(p_cumulative))
249253
index = K
250254
for i in R(2):R(K)
@@ -261,10 +265,11 @@ function Base.rand(rng::AbstractRNG,
261265
p_cumulative = _cumulative(d)
262266
return Dist.support(d)[_rand(rng, p_cumulative, R)]
263267
end
268+
Base.rand(d::UnivariateFinite) = rand(Random.default_rng(), d)
264269

265270
function Base.rand(rng::AbstractRNG,
266271
d::UnivariateFinite{<:Any,<:Any,R},
267-
dim1::Int, moredims::Int...) where R # ref type
272+
dim1::Integer, moredims::Integer...) where R # ref type
268273
p_cumulative = _cumulative(d)
269274
A = Array{R}(undef, dim1, moredims...)
270275
for i in eachindex(A)
@@ -274,7 +279,8 @@ function Base.rand(rng::AbstractRNG,
274279
return broadcast(i -> support[i], A)
275280
end
276281

277-
rng(d::UnivariateFinite, args...) = rng(Random.GLOBAL_RNG, d, args...)
282+
Base.rand(d::UnivariateFinite, dim1::Integer, moredims::Integer...) =
283+
rand(Random.default_rng(), d, dim1, moredims...)
278284

279285
function Dist.fit(d::Type{<:UnivariateFinite},
280286
v::AbstractVector{C}) where C

test/methods.jl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using StableRNGs
88
import Random
99
rng = StableRNG(123)
1010
using ScientificTypes
11+
import Random.default_rng
1112

1213
import CategoricalDistributions: classes, ERR_NAN_FOUND
1314

@@ -127,7 +128,7 @@ end
127128
@testset "broadcasting pdf over single UnivariateFinite object" begin
128129
d = UnivariateFinite(["a", "b"], [0.1, 0.9], pool=missing);
129130
@test pdf.(d, ["a", "b"]) == [0.1, 0.9]
130-
end
131+
end
131132

132133
@testset "constructor arguments not categorical values" begin
133134
@test_throws ArgumentError UnivariateFinite(Dict('f'=>0.7, 'q'=>0.2))
@@ -299,6 +300,26 @@ end
299300
@test displays_okay([5 + 3im, 4 - 7im])
300301
end
301302

303+
if VERSION >= v"1.7"
304+
@testset "rand signatures" begin
305+
d = UnivariateFinite(
306+
["maybe", "no", "yes"],
307+
[0.5, 0.4, 0.1];
308+
pool=missing,
309+
)
310+
311+
Random.seed!(123)
312+
samples = [rand(default_rng(), d) for i in 1:30]
313+
Random.seed!(123)
314+
@test [rand(d) for i in 1:30] == samples
315+
316+
Random.seed!(123)
317+
samples = rand(Random.default_rng(), d, 3, 5)
318+
Random.seed!(123)
319+
@test samples == rand(d, 3, 5)
320+
end
321+
end
322+
302323
end # module
303324

304325
true

0 commit comments

Comments
 (0)