Commit 09e821e

UnivariateFiniteVector -> UnivariateFiniteArray (not tested)
1 parent 7d53614 commit 09e821e

3 files changed: +108 -55 lines changed

src/MLJModelInterface.jl

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ export LightInterface, FullInterface
 # MLJ model hierarchy
 export MLJType, Model, Supervised, Unsupervised,
 Probabilistic, Deterministic, Interval, Static,
-UnivariateFinite, UnivariateFiniteVector
+UnivariateFinite, UnivariateFiniteArray

 # model constructor + metadata
 export @mlj_model, metadata_pkg, metadata_model

src/data_utils.jl

Lines changed: 106 additions & 53 deletions
@@ -21,8 +21,8 @@ categorical(::LightInterface, a...; kw...) = errlight("categorical")
 matrix(X; transpose=false)

 If `X <: AbstractMatrix`, return `X` or `permutedims(X)` if `transpose=true`.
-If `X` is a Tables.jl compatible table source, convert `X` into a `Matrix`
-$REQUIRE.
+If `X` is a Tables.jl compatible table source, convert `X` into a `Matrix`.
+
 """
 matrix(X; kw...) = matrix(get_interface_mode(), vtrait(X), X; kw...)

@@ -43,7 +43,7 @@ matrix(::LightInterface, ::Val{:table}, X; kw...) = errlight("matrix")

 The positional integer of the `CategoricalString` or `CategoricalValue` `x`, in
 the ordering defined by the pool of `x`. The type of `int(x)` is the reference
-type of `x` $REQUIRE.
+type of `x`.

 Not to be confused with `x.ref`, which is unchanged by reordering of the pool
 of `x`, but has the same type.
@@ -82,10 +82,11 @@ int(::LightInterface, x) = errlight("int")
 """
 classes(x)

-All the categorical elements with the same pool as `x` (including `x`),
-returned as a list, with an ordering consistent with the pool $REQUIRE.
-Here `x` has `CategoricalValue` or `CategoricalString` type, and `classes(x)`
-is a vector of the same eltype. Note that `x in classes(x)` is always true.
+All the categorical elements with the same pool as `x` (including
+`x`), returned as a list, with an ordering consistent with the pool.
+Here `x` has `CategoricalValue` or `CategoricalString` type, and
+`classes(x)` is a vector of the same eltype. Note that `x in
+classes(x)` is always true.

 Not to be confused with `levels(x.pool)`. See the example below.

@@ -146,10 +147,10 @@ schema(::LightInterface, ::Val{:table}, X; kw...) = errlight("schema")
 d = decoder(x)

 A callable object for decoding the integer representation of a
-`CategoricalString` or `CategoricalValue` sharing the same pool as `x`
-$REQUIRE. (Here `x` is of one of these two types.) Specifically, one has
-`d(int(y)) == y` for all `y in classes(x)`. One can also call `d` on integer
-arrays, in which case `d` is broadcast over all elements.
+`CategoricalString` or `CategoricalValue` sharing the same pool as
+`x`. (Here `x` is of one of these two types.) Specifically, one has
+`d(int(y)) == y` for all `y in classes(x)`. One can also call `d` on
+integer arrays, in which case `d` is broadcast over all elements.

 julia> v = categorical([:c, :b, :c, :a])
 julia> int(v)
@@ -176,17 +177,17 @@ decoder(::LightInterface, x) = errlight("decoder")
 """
 table(columntable; prototype=nothing)

-Convert a named tuple of vectors or tuples `columntable`, into a table of the
-"preferred sink type" of `prototype` $REQUIRE. This is often the type of
+Convert a named tuple of vectors or tuples `columntable`, into a table
+of the "preferred sink type" of `prototype`. This is often the type of
 `prototype` itself, when `prototype` is a sink; see the Tables.jl
-documentation. If `prototype` is not specified, then a named tuple of vectors
-is returned.
+documentation. If `prototype` is not specified, then a named tuple of
+vectors is returned.

 table(A::AbstractMatrix; names=nothing, prototype=nothing)

-Wrap an abstract matrix `A` as a Tables.jl compatible table with the specified
-column `names` (a tuple of symbols). If `names` are not specified,
-`names=(:x1, :x2, ..., :xn)` is used, where `n=size(A, 2)` $REQUIRE.
+Wrap an abstract matrix `A` as a Tables.jl compatible table with the
+specified column `names` (a tuple of symbols). If `names` are not
+specified, `names=(:x1, :x2, ..., :xn)` is used, where `n=size(A, 2)`.

 If a `prototype` is specified, then the matrix is materialized as a table of
 the preferred sink type of `prototype`, rather than wrapped. Note that if
@@ -202,7 +203,7 @@ table(::LightInterface, X; kw...) = errlight("table")
 """
 nrows(X)

-Return the number of rows for a table, abstract vector or matrix `X` $REQUIRE.
+Return the number of rows for a table, abstract vector or matrix `X`.
 """
 nrows(X) = nrows(get_interface_mode(), vtrait(X), X)

@@ -219,9 +220,11 @@ nrows(::LightInterface, ::Val{:table}, X) = errlight("table")
 """
 selectrows(X, r)

-Select single or multiple rows from a table, abstract vector or matrix `X`
-$REQUIRE. If `X` is tabular, the object returned is a table of the
-preferred sink type of `typeof(X)`, even if only a single row is selected.
+Select single or multiple rows from a table, abstract vector or matrix
+`X`. If `X` is tabular, the object returned is a table of the
+preferred sink type of `typeof(X)`, even if only a single row is
+selected.
+
 """
 selectrows(X, r) = selectrows(get_interface_mode(), vtrait(X), X, r)

@@ -245,10 +248,11 @@ selectrows(::LightInterface, ::Val{:table}, X, r; kw...) =
 """
 selectcols(X, c)

-Select single or multiple columns from a matrix or table `X` $REQUIRE. If `c`
+Select single or multiple columns from a matrix or table `X`. If `c`
 is an abstract vector of integers or symbols, then the object returned
 is a table of the preferred sink type of `typeof(X)`. If `c` is a
 *single* integer or column, then an `AbstractVector` is returned.
+
 """
 selectcols(X, c) = selectcols(get_interface_mode(), vtrait(X), X, c)

@@ -292,63 +296,112 @@ _squeeze(v) = first(v)

 const UNIVARIATE_FINITE_DOCSTRING =
 """
-UnivariateFinite(classes, p; pool=nothing, ordered=false)
+UnivariateFinite(support, probs; pool=nothing, ordered=false)

 Construct a discrete univariate distribution whose finite support is
-the elements of the vector `classes`, and whose corresponding
-probabilities are elements of the vector `p`, which must sum to one $REQUIRE.
+the elements of the vector `support`, and whose corresponding
+probabilities are elements of the vector `probs`, which must sum to
+one.

-*Important.* Here `classes` must have type
+Unless `pool` is specified, `support` must have type
 `AbstractVector{<:CategoricalValue}` and all elements are assumed to
-share the same categorical pool. Raw classes *may* be used, but only provided
-`pool` is specified. The possible values are:
+share the same categorical pool.
+
+```
+using CategoricalArrays
+v = categorical([:x, :x, :y, :x, :z])
+
+julia> UnivariateFinite(classes(v), [0.2, 0.3, 0.5])
+UnivariateFinite(x=>0.2, y=>0.3, z=>0.5) (Multiclass{3} samples)
+
+julia> d = UnivariateFinite([v[1], v[end]], [0.1, 0.9])
+UnivariateFinite(x=>0.1, z=>0.9) (Multiclass{3} samples)
+
+julia> pdf(d, :y)
+0.0
+
+```
+
+Alternatively, `support` may be a list of raw (non-categorical)
+elements if `pool` is:

-- some `v::CategoricalVector` such that `classes` is a subset of `levels(v)`
+- some `v::CategoricalVector` such that `support` is a subset of
+`levels(v)`

-- some `a::CategoricalValue` such that `classes` is a subset of `levels(a)`
+- some `a::CategoricalValue` such that `support` is a subset of
+`levels(a)`
+
+- some `CategoricalPool` object

 - `missing`, in which case a new categorical pool is created which has
-`classes` as its only levels.
+`support` as its only levels.
+
+In the last case, specify `ordered=true` if the pool is to be
+considered ordered.
+
+```
+julia> UnivariateFinite([:x, :z], [0.1, 0.9], pool=missing, ordered=true)
+UnivariateFinite(x=>0.1, z=>0.9) (OrderedFactor{2} samples)
+
+julia> d = UnivariateFinite([:x, :z], [0.1, 0.9], pool=v) # v defined above
+UnivariateFinite(x=>0.1, z=>0.9) (Multiclass{3} samples)

-In the last case specify `ordered=true` to order the new pool.
+julia> pdf(d, :y) # allowed as `:y in levels(v)`
+0.0
+```

-UnivariateFinite(prob_given_class; pool=nothing, ordered=false)
+UnivariateFinite(prob_given_class; pool=nothing, ordered=false)

 Construct a discrete univariate distribution whose finite support is
 the set of keys of the provided dictionary, `prob_given_class`, and
-whose values specify the corresponding probabilities $REQUIRE.
+whose values specify the corresponding probabilities.

 The type requirements on the keys of the dictionary are the same as
-`classes` above.
+`support` above.

 """
 UnivariateFinite(d::AbstractDict; kwargs...) =
 UnivariateFinite(get_interface_mode(), d; kwargs...)
-UnivariateFinite(c::AbstractVector, p; kwargs...) =
-UnivariateFinite(get_interface_mode(), c, p; kwargs...)
+UnivariateFinite(support::AbstractVector, probs; kwargs...) =
+UnivariateFinite(get_interface_mode(), support, probs; kwargs...)
+UnivariateFinite(probs; kwargs...) =
+UnivariateFinite(get_interface_mode(), probs; kwargs...)

 UnivariateFinite(::LightInterface, a...; kwargs...) =
 errlight("UnivariateFinite")

 const UNIVARIATE_FINITE_VECTOR_DOCSTRING =
 """
-UnivariateFiniteVector(classes, p; pool=nothing, ordered=false)
+UnivariateFiniteArray(support, probs; pool=nothing, ordered=false)

-Container for UnivariateFinite elements optimised for efficiency.
-Accessing a single element will construct and return the corresponding
-UnivariateFinite lazily.
+Construct a performant array of `UnivariateFinite` elements.

-Here the probabalities `p` should be an array with `size(p, 2) = N`,
-where `N = length(classes)` and rows sum to one.
+For an explanation of `support` and the keyword arguments, see
+[`UnivariateFinite`](@ref) . Here `probs` should be an array with
+`size(probs, 1) = C`, where `C = length(support)`, and its elements
+should sum to one along the first dimension.

-See [`UnivariateFinite`](@ref) for explanation of the `pool` and
-`ordered` key-word arguments.
+In the special binary case `prob` may be a vector of arbitrary `Real`
+elements between 0 and 1, signifying the probabilities of the first
+element of `support`.

-"""
-UnivariateFiniteVector(s::AbstractArray; kwargs...) =
-UnivariateFiniteVector(get_interface_mode(), s; kwargs...)
-UnivariateFiniteVector(c::AbstractArray, s::AbstractArray; kwargs...) =
-UnivariateFiniteVector(get_interface_mode(), c, s; kwargs...)
+```
+using CategoricalArrays
+v = categorical([:x, :x, :y, :x, :z, :w])
+p = rand(6, 3)
+p = p ./ sum(p, dims=2)
+UnivariateFiniteArray([v[1], v[3], v[5]], p)
+
+UnivariateFiniteArray([:x, :z, :z], pool=missing, ordered=true)

-UnivariateFiniteVector(::LightInterface, a...; kwargs...) =
-errlight("UnivariateFiniteVector")
+```
+
+"""
+UnivariateFiniteArray(probs::AbstractArray; kwargs...) =
+UnivariateFiniteArray(get_interface_mode(), probs; kwargs...)
+UnivariateFiniteArray(support::AbstractArray,
+probs::AbstractArray; kwargs...) =
+UnivariateFiniteArray(get_interface_mode(), support, probs; kwargs...)
+
+UnivariateFiniteArray(::LightInterface, a...; kwargs...) =
+errlight("UnivariateFiniteArray")
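
Note on the `probs` layout: the prose of the new `UnivariateFiniteArray` docstring asks for `size(probs, 1) = C` with `C = length(support)` and sums of one along the first dimension, while the docstring's own example builds `rand(6, 3)` and normalises with `dims=2`. The commit is marked "not tested", so the diff alone does not settle which layout is intended. The plain-Julia sketch below follows the prose (classes along the first dimension). It calls nothing from MLJModelInterface; all variable names are illustrative, and the way the binary case is expanded into two rows is an assumption made here, not something taken from the commit.

```julia
# Plain-Julia sketch of the `probs` layout described in the docstring prose.
# No MLJ packages are needed; every name below is illustrative.

C, n = 3, 6                           # C = length(support), n = number of distributions
probs = rand(C, n)                    # one column per distribution
probs = probs ./ sum(probs, dims=1)   # normalise along the first dimension
column_sums = vec(sum(probs, dims=1)) # each entry should be (approximately) one

# Binary special case: a vector of probabilities for the first element of
# `support`. Stacking the complement underneath is an assumed expansion,
# shown only to make the layout concrete.
p_first = [0.2, 0.9, 0.5]
binary_probs = permutedims(hcat(p_first, 1 .- p_first))  # size (2, 3)
```

With this layout, `probs[:, j]` holds the class probabilities of the `j`-th distribution, in the order of `support`.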

test/data_utils.jl

Lines changed: 1 addition & 1 deletion
@@ -230,5 +230,5 @@ end
 setlight()
 @test_throws M.InterfaceError UnivariateFinite(Dict(2=>3,3=>4))
 @test_throws M.InterfaceError UnivariateFinite(randn(2), randn(2))
-@test_throws M.InterfaceError UnivariateFiniteVector(randn(2), randn(2))
+@test_throws M.InterfaceError UnivariateFiniteArray(randn(2), randn(2))
 end
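
The tests above rely on the pattern visible throughout `src/data_utils.jl`: each public function forwards to a method that takes the current interface mode as its first argument, and the `LightInterface` method throws via `errlight`. A minimal, self-contained sketch of that pattern follows. The type and function names mirror those in the diff, but every definition here is a stand-in written for illustration: the mode storage, the error message, and the toy function `demo` are assumptions, not the package's code.

```julia
# Toy re-creation of the mode-dispatch pattern seen in the diff.
# All definitions are illustrative stand-ins, not MLJModelInterface code.

abstract type Mode end
struct LightInterface <: Mode end
struct FullInterface <: Mode end

struct InterfaceError <: Exception
    msg::String
end

# Hypothetical global holding the current mode; the package may store it differently.
const INTERFACE_MODE = Ref{Mode}(LightInterface())
get_interface_mode() = INTERFACE_MODE[]

# Hypothetical message text.
errlight(name) = throw(InterfaceError("`$name` is unavailable in light mode"))

# Public entry point forwards to a mode-specific method.
demo(x) = demo(get_interface_mode(), x)
demo(::LightInterface, x) = errlight("demo")   # light mode: always throws
demo(::FullInterface, x) = 2x                  # full mode: real implementation

# Helper: true if calling `f` throws an InterfaceError.
function throws_interface_error(f)
    try
        f()
        return false
    catch err
        return err isa InterfaceError
    end
end

light_result = throws_interface_error(() -> demo(1))  # light mode is active here
INTERFACE_MODE[] = FullInterface()                    # switch modes
full_result = demo(1)
```

Because the fallback is selected by ordinary multiple dispatch on the mode argument, renaming a function (as this commit does for `UnivariateFiniteArray`) only requires renaming its forwarding method and its `LightInterface` method together.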
