215
215
"""
216
216
_cumulative(d::UnivariateFinite)
217
217
218
+ **Private method.**
219
+
218
220
Return the cumulative probability vector `C` for the distribution `d`,
219
221
using only classes in the support of `d`, ordered according to the
220
222
categorical elements used at instantiation of `d`. Used only to
238
240
"""
239
241
_rand(rng, p_cumulative, R)
240
242
243
+ **Private method.**
244
+
241
245
Randomly sample the distribution with discrete support `R(1):R(n)`
242
246
which has cumulative probability vector `p_cumulative` (see
243
247
[`_cummulative`](@ref)).
@@ -256,26 +260,27 @@ function _rand(rng, p_cumulative, R)
256
260
return index
257
261
end
258
262
259
- function Base. rand (rng:: AbstractRNG ,
260
- d:: UnivariateFinite{<:Any,<:Any,R} ) where R
261
- p_cumulative = _cumulative (d)
262
- return Dist. support (d)[_rand (rng, p_cumulative, R)]
263
+ Random. eltype (:: Type{<:UnivariateFinite{<:Any,V}} ) where V = V
264
+
265
+ # The Sampler hook into Random's API is discussed in the Julia documentation, in the
266
+ # Standard Library section on Random.
267
+ function Random. Sampler (
268
+ :: AbstractRNG ,
269
+ d:: UnivariateFinite ,
270
+ :: Random.Repetition ,
271
+ )
272
+ data = (_cumulative (d), Dist. support (d))
273
+ Random. SamplerSimple (d, data)
263
274
end
264
275
265
- function Base. rand (rng:: AbstractRNG ,
266
- d:: UnivariateFinite{<:Any,<:Any,R} ,
267
- dim1:: Int , moredims:: Int... ) where R # ref type
268
- p_cumulative = _cumulative (d)
269
- A = Array {R} (undef, dim1, moredims... )
270
- for i in eachindex (A)
271
- @inbounds A[i] = _rand (rng, p_cumulative, R)
272
- end
273
- support = Dist. support (d)
274
- return broadcast (i -> support[i], A)
276
+ function Base. rand (
277
+ rng:: AbstractRNG ,
278
+ sampler:: Random.SamplerSimple{<:UnivariateFinite{<:Any,<:Any,R}} ,
279
+ ) where R
280
+ p_cumulative, support = sampler. data
281
+ return support[_rand (rng, p_cumulative, R)]
275
282
end
276
283
277
- rng (d:: UnivariateFinite , args... ) = rng (Random. GLOBAL_RNG, d, args... )
278
-
279
284
function Dist. fit (d:: Type{<:UnivariateFinite} ,
280
285
v:: AbstractVector{C} ) where C
281
286
C <: CategoricalValue ||
0 commit comments