Skip to content

Commit 2cc3d3d

Browse files
authored
Merge pull request #66 from JuliaAI/rand-fix
Fix implementation of `rand`
2 parents 5d3934c + db73962 commit 2cc3d3d

File tree

2 files changed

+45
-17
lines changed

2 files changed

+45
-17
lines changed

src/methods.jl

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,8 @@ end
215215
"""
216216
_cumulative(d::UnivariateFinite)
217217
218+
**Private method.**
219+
218220
Return the cumulative probability vector `C` for the distribution `d`,
219221
using only classes in the support of `d`, ordered according to the
220222
categorical elements used at instantiation of `d`. Used only to
@@ -238,6 +240,8 @@ end
238240
"""
239241
_rand(rng, p_cumulative, R)
240242
243+
**Private method.**
244+
241245
Randomly sample the distribution with discrete support `R(1):R(n)`
242246
which has cumulative probability vector `p_cumulative` (see
243247
[`_cummulative`](@ref)).
@@ -256,26 +260,27 @@ function _rand(rng, p_cumulative, R)
256260
return index
257261
end
258262

259-
function Base.rand(rng::AbstractRNG,
260-
d::UnivariateFinite{<:Any,<:Any,R}) where R
261-
p_cumulative = _cumulative(d)
262-
return Dist.support(d)[_rand(rng, p_cumulative, R)]
263+
Random.eltype(::Type{<:UnivariateFinite{<:Any,V}}) where V = V
264+
265+
# The Sampler hook into Random's API is discussed in the Julia documentation, in the
266+
# Standard Library section on Random.
267+
function Random.Sampler(
268+
::AbstractRNG,
269+
d::UnivariateFinite,
270+
::Random.Repetition,
271+
)
272+
data = (_cumulative(d), Dist.support(d))
273+
Random.SamplerSimple(d, data)
263274
end
264275

265-
function Base.rand(rng::AbstractRNG,
266-
d::UnivariateFinite{<:Any,<:Any,R},
267-
dim1::Int, moredims::Int...) where R # ref type
268-
p_cumulative = _cumulative(d)
269-
A = Array{R}(undef, dim1, moredims...)
270-
for i in eachindex(A)
271-
@inbounds A[i] = _rand(rng, p_cumulative, R)
272-
end
273-
support = Dist.support(d)
274-
return broadcast(i -> support[i], A)
276+
function Base.rand(
277+
rng::AbstractRNG,
278+
sampler::Random.SamplerSimple{<:UnivariateFinite{<:Any,<:Any,R}},
279+
) where R
280+
p_cumulative, support = sampler.data
281+
return support[_rand(rng, p_cumulative, R)]
275282
end
276283

277-
rng(d::UnivariateFinite, args...) = rng(Random.GLOBAL_RNG, d, args...)
278-
279284
function Dist.fit(d::Type{<:UnivariateFinite},
280285
v::AbstractVector{C}) where C
281286
C <: CategoricalValue ||

test/methods.jl

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using StableRNGs
88
import Random
99
rng = StableRNG(123)
1010
using ScientificTypes
11+
import Random.default_rng
1112

1213
import CategoricalDistributions: classes, ERR_NAN_FOUND
1314

@@ -127,7 +128,7 @@ end
127128
@testset "broadcasting pdf over single UnivariateFinite object" begin
128129
d = UnivariateFinite(["a", "b"], [0.1, 0.9], pool=missing);
129130
@test pdf.(d, ["a", "b"]) == [0.1, 0.9]
130-
end
131+
end
131132

132133
@testset "constructor arguments not categorical values" begin
133134
@test_throws ArgumentError UnivariateFinite(Dict('f'=>0.7, 'q'=>0.2))
@@ -299,6 +300,28 @@ end
299300
@test displays_okay([5 + 3im, 4 - 7im])
300301
end
301302

303+
@testset "rand signatures" begin
304+
d = UnivariateFinite(
305+
["maybe", "no", "yes"],
306+
[0.5, 0.4, 0.1];
307+
pool=missing,
308+
)
309+
310+
# smoke test:
311+
sampler = Random.Sampler(default_rng(), d, Val(1))
312+
rand(default_rng(), sampler)
313+
314+
Random.seed!(123)
315+
samples = [rand(default_rng(), d) for i in 1:30]
316+
Random.seed!(123)
317+
@test [rand(d) for i in 1:30] == samples
318+
319+
Random.seed!(123)
320+
samples = rand(Random.default_rng(), d, 3, 5)
321+
Random.seed!(123)
322+
@test samples == rand(d, 3, 5)
323+
end
324+
302325
end # module
303326

304327
true

0 commit comments

Comments
 (0)