Skip to content

Commit db73962

Browse files
committed
instead, implement Random.Sampler and Rand.rand(rng, ::Sampler{...})
1 parent df4ee21 commit db73962

File tree

2 files changed

+38
-37
lines changed

2 files changed

+38
-37
lines changed

src/methods.jl

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -260,27 +260,26 @@ function _rand(rng, p_cumulative, R)
260260
return index
261261
end
262262

263-
function Base.rand(rng::AbstractRNG,
264-
d::UnivariateFinite{<:Any,<:Any,R}) where R
265-
p_cumulative = _cumulative(d)
266-
return Dist.support(d)[_rand(rng, p_cumulative, R)]
267-
end
268-
Base.rand(d::UnivariateFinite) = rand(Random.default_rng(), d)
269-
270-
function Base.rand(rng::AbstractRNG,
271-
d::UnivariateFinite{<:Any,<:Any,R},
272-
dim1::Integer, moredims::Integer...) where R # ref type
273-
p_cumulative = _cumulative(d)
274-
A = Array{R}(undef, dim1, moredims...)
275-
for i in eachindex(A)
276-
@inbounds A[i] = _rand(rng, p_cumulative, R)
277-
end
278-
support = Dist.support(d)
279-
return broadcast(i -> support[i], A)
263+
Random.eltype(::Type{<:UnivariateFinite{<:Any,V}}) where V = V
264+
265+
# The Sampler hook into Random's API is discussed in the Julia documentation, in the
266+
# Standard Library section on Random.
267+
function Random.Sampler(
268+
::AbstractRNG,
269+
d::UnivariateFinite,
270+
::Random.Repetition,
271+
)
272+
data = (_cumulative(d), Dist.support(d))
273+
Random.SamplerSimple(d, data)
280274
end
281275

282-
Base.rand(d::UnivariateFinite, dim1::Integer, moredims::Integer...) =
283-
rand(Random.default_rng(), d, dim1, moredims...)
276+
function Base.rand(
277+
rng::AbstractRNG,
278+
sampler::Random.SamplerSimple{<:UnivariateFinite{<:Any,<:Any,R}},
279+
) where R
280+
p_cumulative, support = sampler.data
281+
return support[_rand(rng, p_cumulative, R)]
282+
end
284283

285284
function Dist.fit(d::Type{<:UnivariateFinite},
286285
v::AbstractVector{C}) where C

test/methods.jl

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -300,24 +300,26 @@ end
300300
@test displays_okay([5 + 3im, 4 - 7im])
301301
end
302302

303-
if VERSION >= v"1.7"
304-
@testset "rand signatures" begin
305-
d = UnivariateFinite(
306-
["maybe", "no", "yes"],
307-
[0.5, 0.4, 0.1];
308-
pool=missing,
309-
)
310-
311-
Random.seed!(123)
312-
samples = [rand(default_rng(), d) for i in 1:30]
313-
Random.seed!(123)
314-
@test [rand(d) for i in 1:30] == samples
315-
316-
Random.seed!(123)
317-
samples = rand(Random.default_rng(), d, 3, 5)
318-
Random.seed!(123)
319-
@test samples == rand(d, 3, 5)
320-
end
303+
@testset "rand signatures" begin
304+
d = UnivariateFinite(
305+
["maybe", "no", "yes"],
306+
[0.5, 0.4, 0.1];
307+
pool=missing,
308+
)
309+
310+
# smoke test:
311+
sampler = Random.Sampler(default_rng(), d, Val(1))
312+
rand(default_rng(), sampler)
313+
314+
Random.seed!(123)
315+
samples = [rand(default_rng(), d) for i in 1:30]
316+
Random.seed!(123)
317+
@test [rand(d) for i in 1:30] == samples
318+
319+
Random.seed!(123)
320+
samples = rand(Random.default_rng(), d, 3, 5)
321+
Random.seed!(123)
322+
@test samples == rand(d, 3, 5)
321323
end
322324

323325
end # module

0 commit comments

Comments
 (0)