1
1
# not for export:
2
- const UnivariateFiniteUnion =
3
- Union{UnivariateFinite, UnivariateFiniteArray}
2
+ const UnivariateFiniteUnion{S,V,R,P} =
3
+ Union{UnivariateFinite{S,V,R,P}, UnivariateFiniteArray{S,V,R,P} }
4
4
5
5
"""
6
6
classes(d::UnivariateFinite)
42
42
raw_support (d:: UnivariateFiniteUnion ) = collect (keys (d. prob_given_ref))
43
43
44
44
"""
45
- Dist .support(d::UnivariateFinite)
46
- Dist .support(d::UnivariateFiniteArray)
45
+ Distributions .support(d::UnivariateFinite)
46
+ Distributions .support(d::UnivariateFiniteArray)
47
47
48
48
Ordered list of classes associated with non-zero probabilities.
49
49
50
50
v = categorical(["yes", "maybe", "no", "yes"])
51
51
d = UnivariateFinite(v[1:2], [0.3, 0.7])
52
- support(d) # CategoricalArray{String,1,UInt32}["maybe", "yes"]
52
+ Distributions. support(d) # CategoricalArray{String,1,UInt32}["maybe", "yes"]
53
53
54
54
"""
55
- Dist. support (d:: UnivariateFiniteUnion ) =
56
- map (d. decoder, raw_support (d))
55
+ Distributions. support (d:: UnivariateFiniteUnion ) = classes (d)[raw_support (d)]
56
+
57
+ """
58
+ fast_support(d::UnivariateFinite)
59
+
60
+ Same as `Distributions.support(d)` except it returns a vector of `CategoricalValue`s,
61
+ rather than a `CategoricalVector`. It executes faster, about five times faster for a
62
+ three-class `UnivariateFinite` distribution.
63
+ """
64
+ function fast_support (d:: UnivariateFiniteUnion{S,V,R} ) where {S,V,R}
65
+ raw_support = keys (d. prob_given_ref)
66
+ n = length (raw_support)
67
+ ret = Vector {CategoricalValue{V,R}} (undef, n)
68
+ for (i, ref) in enumerate (raw_support)
69
+ ret[i] = d. decoder (ref)
70
+ end
71
+ ret
72
+ end
57
73
58
74
# TODO : If I manually give a class zero probability, it will appear in
59
75
# support, which is probably confusing. We may need two versions of
@@ -64,8 +80,7 @@ Dist.support(d::UnivariateFiniteUnion) =
64
80
# not exported:
65
81
sample_scitype (d:: UnivariateFiniteUnion ) = d. scitype
66
82
67
- CategoricalArrays. isordered (d:: UnivariateFinite ) = isordered (classes (d))
68
- CategoricalArrays. isordered (u:: UnivariateFiniteArray ) = isordered (classes (u))
83
+ CategoricalArrays. isordered (d:: UnivariateFiniteUnion ) = isordered (classes (d))
69
84
70
85
71
86
# # DISPLAY
@@ -96,8 +111,8 @@ probability pairs. Returns `false` otherwise.
96
111
97
112
"""
98
113
function Base. isapprox (d1:: UnivariateFinite , d2:: UnivariateFinite ; kwargs... )
99
- support1 = Dist . support (d1)
100
- support2 = Dist . support (d2)
114
+ support1 = fast_support (d1)
115
+ support2 = fast_support (d2)
101
116
for c in support1
102
117
c in support2 || return false
103
118
isapprox (pdf (d1, c), pdf (d2, c); kwargs... ) ||
@@ -107,8 +122,8 @@ function Base.isapprox(d1::UnivariateFinite, d2::UnivariateFinite; kwargs...)
107
122
end
108
123
function Base. isapprox (d1:: UnivariateFiniteArray ,
109
124
d2:: UnivariateFiniteArray ; kwargs... )
110
- support1 = Dist . support (d1)
111
- support2 = Dist . support (d2)
125
+ support1 = fast_support (d1)
126
+ support2 = fast_support (d2)
112
127
for c in support1
113
128
c in support2 || return false
114
129
isapprox (pdf .(d1, c), pdf .(d2, c); kwargs... ) ||
@@ -206,22 +221,18 @@ function throw_nan_error_if_needed(x)
206
221
end
207
222
end
208
223
209
- # mode(v::Vector{UnivariateFinite}) = mode.(v)
210
- # mode(u::UnivariateFiniteVector{2}) =
211
- # [u.support[ifelse(s > 0.5, 2, 1)] for s in u.scores]
212
- # mode(u::UnivariateFiniteVector{C}) where {C} =
213
- # [u.support[findmax(s)[2]] for s in eachrow(u.scores)]
224
+
225
+ # # HELPERS FOR RAND
214
226
215
227
"""
216
228
_cumulative(d::UnivariateFinite)
217
229
218
230
**Private method.**
219
231
220
- Return the cumulative probability vector `C` for the distribution `d`,
221
- using only classes in the support of `d`, ordered according to the
222
- categorical elements used at instantiation of `d`. Used only to
223
- implement random sampling from `d`. We have `C[1] == 0` and `C[end] ==
224
- 1`, assuming the probabilities have been normalized.
232
+ Return the cumulative probability vector `C` for the distribution `d`, using only classes
233
+ in `Distributions.support(d)`, ordered according to the categorical elements used at
234
+ instantiation of `d`. Used only to implement random sampling from `d`. We have `C[1] == 0`
235
+ and `C[end] == 1`, assuming the probabilities have been normalized.
225
236
226
237
"""
227
238
function _cumulative (d:: UnivariateFinite{S,V,R,P} ) where {S,V,R,P<: Real }
@@ -260,16 +271,54 @@ function _rand(rng, p_cumulative, R)
260
271
return index
261
272
end
262
273
263
- Random. eltype (:: Type{<:UnivariateFinite{<:Any,V}} ) where V = V
274
+
275
+ # # RAND
276
+
277
+ Random. eltype (:: Type{<:UnivariateFinite{S,V,R}} ) where {S,V,R} =
278
+ CategoricalArrays. CategoricalValue{V,R}
264
279
265
280
# The Sampler hook into Random's API is discussed in the Julia documentation, in the
266
281
# Standard Library section on Random.
282
+
283
+
284
+ # # Single samples
285
+
286
+ Random. Sampler (:: AbstractRNG , d:: UnivariateFinite , :: Val{1} ) = Random. SamplerTrivial (d)
287
+
288
+ function Base. rand (
289
+ rng:: AbstractRNG ,
290
+ sampler:: Random.SamplerTrivial{<:UnivariateFinite{<:Any,<:Any,V,P}} ,
291
+ ) where {V, P}
292
+
293
+ d = sampler[]
294
+ u = rand (rng)
295
+
296
+ total = zero (P)
297
+
298
+ # For type stability we assign `zero(V)`` as the default ref
299
+ # This isn't a problem since we know that `rand` is always defined
300
+ # as UnivariateFinite objects have non-negative probabilities,
301
+ # summing up to a non-negative value.
302
+ rng_key = zero (V)
303
+ for (ref, prob) in pairs (d. prob_given_ref)
304
+ total += prob
305
+ u <= total && begin
306
+ rng_key = ref
307
+ break
308
+ end
309
+ end
310
+ return d. decoder (rng_key)
311
+ end
312
+
313
+
314
+ # # Multiple samples
315
+
267
316
function Random. Sampler (
268
317
:: AbstractRNG ,
269
318
d:: UnivariateFinite ,
270
319
:: Random.Repetition ,
271
320
)
272
- data = (_cumulative (d), Dist . support (d))
321
+ data = (_cumulative (d), fast_support (d))
273
322
Random. SamplerSimple (d, data)
274
323
end
275
324
@@ -281,6 +330,9 @@ function Base.rand(
281
330
return support[_rand (rng, p_cumulative, R)]
282
331
end
283
332
333
+
334
+ # # FIT
335
+
284
336
function Dist. fit (d:: Type{<:UnivariateFinite} ,
285
337
v:: AbstractVector{C} ) where C
286
338
C <: CategoricalValue ||
0 commit comments