Skip to content

Commit 70c9439

Browse files
authored
Merge pull request #47 from alan-turing-institute/univariate2
UnivariateFinite arrays, Take II
2 parents ec5e745 + 485e382 commit 70c9439

File tree

3 files changed

+128
-57
lines changed

3 files changed

+128
-57
lines changed

src/MLJModelInterface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ export LightInterface, FullInterface
1414
# MLJ model hierarchy
1515
export MLJType, Model, Supervised, Unsupervised,
1616
Probabilistic, Deterministic, Interval, Static,
17-
UnivariateFinite, UnivariateFiniteVector
17+
UnivariateFinite
1818

1919
# model constructor + metadata
2020
export @mlj_model, metadata_pkg, metadata_model

src/data_utils.jl

Lines changed: 127 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ categorical(::LightInterface, a...; kw...) = errlight("categorical")
2121
matrix(X; transpose=false)
2222
2323
If `X isa AbstractMatrix`, return `X` or `permutedims(X)` if `transpose=true`.
24-
If `X` is a Tables.jl compatible table source, convert `X` into a `Matrix`
25-
$REQUIRE.
24+
If `X` is a Tables.jl compatible table source, convert `X` into a `Matrix`.
25+
2626
"""
2727
matrix(X; kw...) = matrix(get_interface_mode(), vtrait(X), X; kw...)
2828

@@ -43,7 +43,7 @@ matrix(::LightInterface, ::Val{:table}, X; kw...) = errlight("matrix")
4343
4444
The positional integer of the `CategoricalString` or `CategoricalValue` `x`, in
4545
the ordering defined by the pool of `x`. The type of `int(x)` is the reference
46-
type of `x` $REQUIRE.
46+
type of `x`.
4747
4848
Not to be confused with `x.ref`, which is unchanged by reordering of the pool
4949
of `x`, but has the same type.
@@ -82,10 +82,11 @@ int(::LightInterface, x) = errlight("int")
8282
"""
8383
classes(x)
8484
85-
All the categorical elements with the same pool as `x` (including `x`),
86-
returned as a list, with an ordering consistent with the pool $REQUIRE.
87-
Here `x` has `CategoricalValue` or `CategoricalString` type, and `classes(x)`
88-
is a vector of the same eltype. Note that `x in classes(x)` is always true.
85+
All the categorical elements with the same pool as `x` (including
86+
`x`), returned as a list, with an ordering consistent with the pool.
87+
Here `x` has `CategoricalValue` or `CategoricalString` type, and
88+
`classes(x)` is a vector of the same eltype. Note that `x in
89+
classes(x)` is always true.
8990
9091
Not to be confused with `levels(x.pool)`. See the example below.
9192
@@ -160,10 +161,10 @@ istable(::Mode, ::Val{:table}) = true
160161
d = decoder(x)
161162
162163
A callable object for decoding the integer representation of a
163-
`CategoricalString` or `CategoricalValue` sharing the same pool as `x`
164-
$REQUIRE. (Here `x` is of one of these two types.) Specifically, one has
165-
`d(int(y)) == y` for all `y in classes(x)`. One can also call `d` on integer
166-
arrays, in which case `d` is broadcast over all elements.
164+
`CategoricalString` or `CategoricalValue` sharing the same pool as
165+
`x`. (Here `x` is of one of these two types.) Specifically, one has
166+
`d(int(y)) == y` for all `y in classes(x)`. One can also call `d` on
167+
integer arrays, in which case `d` is broadcast over all elements.
167168
168169
julia> v = categorical([:c, :b, :c, :a])
169170
julia> int(v)
@@ -190,17 +191,17 @@ decoder(::LightInterface, x) = errlight("decoder")
190191
"""
191192
table(columntable; prototype=nothing)
192193
193-
Convert a named tuple of vectors or tuples `columntable`, into a table of the
194-
"preferred sink type" of `prototype` $REQUIRE. This is often the type of
194+
Convert a named tuple of vectors or tuples `columntable`, into a table
195+
of the "preferred sink type" of `prototype`. This is often the type of
195196
`prototype` itself, when `prototype` is a sink; see the Tables.jl
196-
documentation. If `prototype` is not specified, then a named tuple of vectors
197-
is returned.
197+
documentation. If `prototype` is not specified, then a named tuple of
198+
vectors is returned.
198199
199200
table(A::AbstractMatrix; names=nothing, prototype=nothing)
200201
201-
Wrap an abstract matrix `A` as a Tables.jl compatible table with the specified
202-
column `names` (a tuple of symbols). If `names` are not specified,
203-
`names=(:x1, :x2, ..., :xn)` is used, where `n=size(A, 2)` $REQUIRE.
202+
Wrap an abstract matrix `A` as a Tables.jl compatible table with the
203+
specified column `names` (a tuple of symbols). If `names` are not
204+
specified, `names=(:x1, :x2, ..., :xn)` is used, where `n=size(A, 2)`.
204205
205206
If a `prototype` is specified, then the matrix is materialized as a table of
206207
the preferred sink type of `prototype`, rather than wrapped. Note that if
@@ -216,7 +217,7 @@ table(::LightInterface, X; kw...) = errlight("table")
216217
"""
217218
nrows(X)
218219
219-
Return the number of rows for a table, abstract vector or matrix `X` $REQUIRE.
220+
Return the number of rows for a table, abstract vector or matrix `X`.
220221
"""
221222
nrows(X) = nrows(get_interface_mode(), vtrait(X), X)
222223

@@ -233,9 +234,11 @@ nrows(::LightInterface, ::Val{:table}, X) = errlight("table")
233234
"""
234235
selectrows(X, r)
235236
236-
Select single or multiple rows from a table, abstract vector or matrix `X`
237-
$REQUIRE. If `X` is tabular, the object returned is a table of the
238-
preferred sink type of `typeof(X)`, even if only a single row is selected.
237+
Select single or multiple rows from a table, abstract vector or matrix
238+
`X`. If `X` is tabular, the object returned is a table of the
239+
preferred sink type of `typeof(X)`, even if only a single row is
240+
selected.
241+
239242
"""
240243
selectrows(X, r) = selectrows(get_interface_mode(), vtrait(X), X, r)
241244

@@ -259,10 +262,11 @@ selectrows(::LightInterface, ::Val{:table}, X, r; kw...) =
259262
"""
260263
selectcols(X, c)
261264
262-
Select single or multiple columns from a matrix or table `X` $REQUIRE. If `c`
265+
Select single or multiple columns from a matrix or table `X`. If `c`
263266
is an abstract vector of integers or symbols, then the object returned
264267
is a table of the preferred sink type of `typeof(X)`. If `c` is a
265268
*single* integer or symbol, then an `AbstractVector` is returned.
269+
266270
"""
267271
selectcols(X, c) = selectcols(get_interface_mode(), vtrait(X), X, c)
268272

@@ -306,56 +310,124 @@ _squeeze(v) = first(v)
306310

307311
const UNIVARIATE_FINITE_DOCSTRING =
308312
"""
309-
UnivariateFinite(classes, p; pool=nothing, ordered=false)
313+
UnivariateFinite(support,
314+
probs;
315+
pool=nothing,
316+
augment=false,
317+
ordered=false)
310318
311319
Construct a discrete univariate distribution whose finite support is
312-
the elements of the vector `classes`, and whose corresponding
313-
probabilities are elements of the vector `p`, which must sum to one $REQUIRE.
320+
the elements of the vector `support`, and whose corresponding
321+
probabilities are elements of the vector `probs`. More generally,
322+
construct an abstract *array* of `UnivariateFinite` distributions by
323+
choosing `probs` to be an array of one higher dimension than the array
324+
generated.
314325
315-
*Important.* Here `classes` must have type
326+
Unless `pool` is specified, `support` should have type
316327
`AbstractVector{<:CategoricalValue}` and all elements are assumed to
317-
share the same categorical pool. Raw classes *may* be used, but only provided
318-
`pool` is specified. The possible values are:
328+
share the same categorical pool, which may be larger than `support`.
329+
330+
*Important.* All levels of the common pool have associated
331+
probabilities, not just those in the specified `support`. However,
332+
the probabilities of levels outside `support` are always zero (see example below).
333+
334+
If `probs` is a matrix, it should have a column for each class in
335+
`support` (or one less, if `augment=true`). More generally, `probs`
336+
will be an array whose size is of the form `(n1, n2, ..., nk, c)`,
337+
where `c = length(support)` (or one less, if `augment=true`) and the
338+
constructor then returns an array of size `(n1, n2, ..., nk)`.
339+
340+
```
341+
using CategoricalArrays
342+
v = categorical([:x, :x, :y, :x, :z])
343+
344+
julia> UnivariateFinite(classes(v), [0.2, 0.3, 0.5])
345+
UnivariateFinite{Multiclass{3}}(x=>0.2, y=>0.3, z=>0.5)
319346
320-
- some `v::CategoricalVector` such that `classes` is a subset of `levels(v)`
347+
julia> d = UnivariateFinite([v[1], v[end]], [0.1, 0.9])
348+
UnivariateFinite{Multiclass{3}}(x=>0.1, z=>0.9)
321349
322-
- some `a::CategoricalValue` such that `classes` is a subset of `levels(a)`
350+
julia> rand(d, 3)
351+
3-element Array{Any,1}:
352+
CategoricalArrays.CategoricalValue{Symbol,UInt32} :z
353+
CategoricalArrays.CategoricalValue{Symbol,UInt32} :z
354+
CategoricalArrays.CategoricalValue{Symbol,UInt32} :z
355+
356+
julia> levels(d)
357+
3-element Array{Symbol,1}:
358+
:x
359+
:y
360+
:z
361+
362+
julia> pdf(d, :y)
363+
0.0
364+
```
365+
366+
### Specifying a pool
367+
368+
Alternatively, `support` may be a list of raw (non-categorical)
369+
elements if `pool` is:
370+
371+
- some `CategoricalArray`, `CategoricalValue` or `CategoricalPool`,
372+
such that `support` is a subset of `levels(pool)`
323373
324374
- `missing`, in which case a new categorical pool is created which has
325-
`classes` as its only levels.
375+
`support` as its only levels.
376+
377+
In the last case, specify `ordered=true` if the pool is to be
378+
considered ordered.
326379
327-
In the last case specify `ordered=true` to order the new pool.
380+
```
381+
julia> UnivariateFinite([:x, :z], [0.1, 0.9], pool=missing, ordered=true)
382+
UnivariateFinite{OrderedFactor{2}}(x=>0.1, z=>0.9)
328383
329-
UnivariateFinite(prob_given_class; pool=nothing, ordered=false)
384+
julia> d = UnivariateFinite([:x, :z], [0.1, 0.9], pool=v) # v defined above
385+
UnivariateFinite(x=>0.1, z=>0.9) (Multiclass{3} samples)
386+
387+
julia> pdf(d, :y) # allowed as `:y in levels(v)`
388+
0.0
389+
390+
v = categorical([:x, :x, :y, :x, :z, :w])
391+
probs = rand(100, 3)
392+
probs = probs ./ sum(probs, dims=2)
393+
julia> UnivariateFinite([:x, :y, :z], probs, pool=v)
394+
100-element UnivariateFiniteVector{Multiclass{4},Symbol,UInt32,Float64}:
395+
UnivariateFinite{Multiclass{4}}(x=>0.194, y=>0.3, z=>0.505)
396+
UnivariateFinite{Multiclass{4}}(x=>0.727, y=>0.234, z=>0.0391)
397+
UnivariateFinite{Multiclass{4}}(x=>0.674, y=>0.00535, z=>0.321)
398+
399+
UnivariateFinite{Multiclass{4}}(x=>0.292, y=>0.339, z=>0.369)
400+
```
401+
402+
### Probability augmentation
403+
404+
Unless `augment=true`, sums of elements along the last axis (row-sums
405+
in the case of a matrix) must be equal to one, and otherwise such an
406+
array is created by inserting appropriate elements *ahead* of those
407+
provided. This means the provided probabilities are associated with
408+
the classes `c2, c3, ..., cn`.
409+
410+
---
411+
412+
UnivariateFinite(prob_given_class; pool=nothing, ordered=false)
330413
331414
Construct a discrete univariate distribution whose finite support is
332415
the set of keys of the provided dictionary, `prob_given_class`, and
333-
whose values specify the corresponding probabilities $REQUIRE.
416+
whose values specify the corresponding probabilities.
334417
335418
The type requirements on the keys of the dictionary are the same as
336-
`classes` above.
419+
for the elements of `support` given above. If the values (probabilities)
420+
are arrays instead of scalars, then an abstract array of
421+
`UnivariateFinite` elements is created, with the same size as the
422+
array.
337423
338424
"""
425+
UNIVARIATE_FINITE_DOCSTRING
339426
UnivariateFinite(d::AbstractDict; kwargs...) =
340427
UnivariateFinite(get_interface_mode(), d; kwargs...)
341-
UnivariateFinite(c::AbstractVector, p; kwargs...) =
342-
UnivariateFinite(get_interface_mode(), c, p; kwargs...)
343-
428+
UnivariateFinite(support::AbstractVector, probs; kwargs...) =
429+
UnivariateFinite(get_interface_mode(), support, probs; kwargs...)
430+
UnivariateFinite(probs; kwargs...) =
431+
UnivariateFinite(get_interface_mode(), probs; kwargs...)
344432
UnivariateFinite(::LightInterface, a...; kwargs...) =
345433
errlight("UnivariateFinite")
346-
347-
const UNIVARIATE_FINITE_VECTOR_DOCSTRING =
348-
"""
349-
UnivariateFiniteVector(scores, classes)
350-
351-
Container for UnivariateFinite elements optimised for efficiency.
352-
Accessing a single element will construct and return the corresponding
353-
UnivariateFinite lazily.
354-
"""
355-
UnivariateFiniteVector(s::AbstractArray) =
356-
UnivariateFiniteVector(get_interface_mode(), s)
357-
UnivariateFiniteVector(s::AbstractArray, c) =
358-
UnivariateFiniteVector(get_interface_mode(), s, c)
359-
360-
UnivariateFiniteVector(::LightInterface, a...) =
361-
errlight("UnivariateFiniteVector")

test/data_utils.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,5 +240,4 @@ end
240240
setlight()
241241
@test_throws M.InterfaceError UnivariateFinite(Dict(2=>3,3=>4))
242242
@test_throws M.InterfaceError UnivariateFinite(randn(2), randn(2))
243-
@test_throws M.InterfaceError UnivariateFiniteVector(randn(2), randn(2))
244243
end

0 commit comments

Comments
 (0)