Commit cab4cfb

Dump UnivariateFiniteArray as redundant.
doc-string tweaks and again
1 parent 09e821e commit cab4cfb

3 files changed: +49 -54 lines changed

src/MLJModelInterface.jl

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ export LightInterface, FullInterface
 # MLJ model hierarchy
 export MLJType, Model, Supervised, Unsupervised,
 Probabilistic, Deterministic, Interval, Static,
-UnivariateFinite, UnivariateFiniteArray
+UnivariateFinite
 
 # model constructor + metadata
 export @mlj_model, metadata_pkg, metadata_model

src/data_utils.jl

Lines changed: 48 additions & 52 deletions
@@ -300,38 +300,54 @@ const UNIVARIATE_FINITE_DOCSTRING =
 
 Construct a discrete univariate distribution whose finite support is
 the elements of the vector `support`, and whose corresponding
-probabilities are elements of the vector `probs`, which must sum to
-one.
+probabilities are elements of the vector `probs`. More generally,
+construct an abstract *array* of `UnivariateFinite` distributions by
+choosing `probs` to be an array of one higher dimension than the array
+generated.
 
-Unless `pool` is specified, `support` must have type
+Unless `pool` is specified, `support` should have type
 `AbstractVector{<:CategoricalValue}` and all elements are assumed to
 share the same categorical pool.
 
+*Important.* All levels of the common pool have associated
+probabilites, not just those in the specified `support`. However,
+these probabilities are always zero (see example below).
+
+If `probs` has size `(C, n1, n2, ..., nk)` then an array of size `(n1,
+n2, ..., nk)` is created. In all cases elements along the first axis
+always sum to one.
+
 ```
 using CategoricalArrays
 v = categorical([:x, :x, :y, :x, :z])
 
 julia> UnivariateFinite(classes(v), [0.2, 0.3, 0.5])
-UnivariateFinite(x=>0.2, y=>0.3, z=>0.5) (Multiclass{3} samples)
+UnivariateFinite{Multiclass{3}}(x=>0.2, y=>0.3, z=>0.5)
 
 julia> d = UnivariateFinite([v[1], v[end]], [0.1, 0.9])
-UnivariateFinite(x=>0.1, z=>0.9) (Multiclass{3} samples)
+UnivariateFiniteMulticlass{3}(x=>0.1, z=>0.9)
+
+julia> rand(d, 3)
+3-element Array{Any,1}:
+CategoricalArrays.CategoricalValue{Symbol,UInt32} :z
+CategoricalArrays.CategoricalValue{Symbol,UInt32} :z
+CategoricalArrays.CategoricalValue{Symbol,UInt32} :z
+
+julia> levels(d)
+3-element Array{Symbol,1}:
+:x
+:y
+:z
 
 julia> pdf(d, :y)
 0.0
-
 ```
 
 Alternatively, `support` may be a list of raw (non-categorical)
 elements if `pool` is:
 
-- some `v::CategoricalVector` such that `support` is a subset of
-`levels(v)`
-
-- some `a::CategoricalValue` such that `support` is a subset of
-`levels(a)`
-
-- some `CategoricalPool` object
+- some `CategoricalArray`, `CategoricalValue` or `CategoricalPool`,
+such that `support` is a subset of `levels(pool)`
 
 - `missing`, in which case a new categorical pool is created which has
 `support` as its only levels.
@@ -341,67 +357,47 @@ considered ordered.
 
 ```
 julia> UnivariateFinite([:x, :z], [0.1, 0.9], pool=missing, ordered=true)
-UnivariateFinite(x=>0.1, z=>0.9) (OrderedFactor{2} samples)
+UnivariateFinite{OrderedFactor{2}}(x=>0.1, z=>0.9)
 
 julia> d = UnivariateFinite([:x, :z], [0.1, 0.9], pool=v) # v defined above
 UnivariateFinite(x=>0.1, z=>0.9) (Multiclass{3} samples)
 
 julia> pdf(d, :y) # allowed as `:y in levels(v)`
 0.0
+
+v = categorical([:x, :x, :y, :x, :z, :w])
+probs = rand(3, 100)
+probs = probs ./ sum(probs, dims=1)
+julia> UnivariateFinite([:x, :y, :z], probs, pool=v)
+100-element UnivariateFiniteVector{Multiclass{4},Symbol,UInt32,Float64}:
+UnivariateFinite{Multiclass{4}}(x=>0.194, y=>0.3, z=>0.505)
+UnivariateFinite{Multiclass{4}}(x=>0.727, y=>0.234, z=>0.0391)
+UnivariateFinite{Multiclass{4}}(x=>0.674, y=>0.00535, z=>0.321)
+
+UnivariateFinite{Multiclass{4}}(x=>0.292, y=>0.339, z=>0.369)
 ```
 
+---
+
 UnivariateFinite(prob_given_class; pool=nothing, ordered=false)
 
 Construct a discrete univariate distribution whose finite support is
 the set of keys of the provided dictionary, `prob_given_class`, and
 whose values specify the corresponding probabilities.
 
 The type requirements on the keys of the dictionary are the same as
-`support` above.
+the elements of `support` given above. If the values (probabilities)
+are arrays instead of scalars, then an abstract array of
+`UnivariateFinite` elements is created, with the same size as the
+array.
 
 """
+UNIVARIATE_FINITE_DOCSTRING
 UnivariateFinite(d::AbstractDict; kwargs...) =
 UnivariateFinite(get_interface_mode(), d; kwargs...)
 UnivariateFinite(support::AbstractVector, probs; kwargs...) =
 UnivariateFinite(get_interface_mode(), support, probs; kwargs...)
 UnivariateFinite(probs; kwargs...) =
 UnivariateFinite(get_interface_mode(), probs; kwargs...)
-
 UnivariateFinite(::LightInterface, a...; kwargs...) =
 errlight("UnivariateFinite")
-
-const UNIVARIATE_FINITE_VECTOR_DOCSTRING =
-"""
-UnivariateFiniteArray(support, probs; pool=nothing, ordered=false)
-
-Construct a performant array of `UnivariateFinite` elements.
-
-For an explanation of `support` and the keyword arguments, see
-[`UnivariateFinite`](@ref) . Here `probs` should be an array with
-`size(probs, 1) = C`, where `C = length(support)`, and its elements
-should sum to one along the first dimension.
-
-In the special binary case `prob` may be a vector of arbitrary `Real`
-elements between 0 and 1, signifying the probabilities of the first
-element of `support`.
-
-```
-using CategoricalArrays
-v = categorical([:x, :x, :y, :x, :z, :w])
-p = rand(6, 3)
-p = p ./ sum(p, dims=2)
-UnivariateFiniteArray([v[1], v[3], v[5]], p)
-
-UnivariateFiniteArray([:x, :z, :z], pool=missing, ordered=true)
-
-```
-
-"""
-UnivariateFiniteArray(probs::AbstractArray; kwargs...) =
-UnivariateFiniteArray(get_interface_mode(), probs; kwargs...)
-UnivariateFiniteArray(support::AbstractArray,
-probs::AbstractArray; kwargs...) =
-UnivariateFiniteArray(get_interface_mode(), support, probs; kwargs...)
-
-UnivariateFiniteArray(::LightInterface, a...; kwargs...) =
-errlight("UnivariateFiniteArray")

test/data_utils.jl

Lines changed: 0 additions & 1 deletion
@@ -230,5 +230,4 @@ end
 setlight()
 @test_throws M.InterfaceError UnivariateFinite(Dict(2=>3,3=>4))
 @test_throws M.InterfaceError UnivariateFinite(randn(2), randn(2))
-@test_throws M.InterfaceError UnivariateFiniteArray(randn(2), randn(2))
 end
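
The revised `UnivariateFinite` docstring in this commit states that a `probs` array of size `(C, n1, ..., nk)` yields an array of size `(n1, ..., nk)`, with entries along the first axis summing to one. The following is a minimal sketch of that shape convention in plain Julia. It deliberately does not call `UnivariateFinite`, and the helper name `normalize_first_axis` is hypothetical, not part of MLJModelInterface.

```julia
# Hypothetical helper (not part of MLJModelInterface): rescale `probs`
# so that the entries along the first axis sum to one.
normalize_first_axis(probs::AbstractArray{<:Real}) = probs ./ sum(probs, dims=1)

C, n = 3, 100                       # C classes, n distributions
probs = normalize_first_axis(rand(C, n))

# Each column is now a probability vector over the C classes.
@assert all(isapprox.(sum(probs, dims=1), 1.0))

# Dropping the first axis gives the size the docstring describes
# for the generated array of distributions.
@assert size(probs)[2:end] == (n,)

# Probabilities stored one row per observation (n x C), as in the
# removed `UnivariateFiniteArray` example, can be moved to the
# class-first layout with `permutedims`.
p = rand(6, 3)
p = p ./ sum(p, dims=2)             # rows sum to one
q = permutedims(p)                  # size (3, 6); columns sum to one
@assert size(q) == (3, 6)
@assert all(isapprox.(sum(q, dims=1), 1.0))
```

Whether a given array is in the class-first layout is easy to check this way before passing it to a constructor; the sketch assumes nothing about how the constructor itself validates its input.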
