Skip to content

Commit 673761e

Browse files
authored
Merge pull request #46 from alan-turing-institute/dev
For a 0.2.7 release
2 parents b0ebe91 + 70c9439 commit 673761e

File tree

4 files changed

+157
-46
lines changed

4 files changed

+157
-46
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJModelInterface"
22
uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
33
authors = ["Thibaut Lienart and Anthony Blaom"]
4-
version = "0.2.6"
4+
version = "0.2.7"
55

66
[deps]
77
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

src/data_utils.jl

Lines changed: 141 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ categorical(::LightInterface, a...; kw...) = errlight("categorical")
2121
matrix(X; transpose=false)
2222
2323
If `X <: AbstractMatrix`, return `X` or `permutedims(X)` if `transpose=true`.
24-
If `X` is a Tables.jl compatible table source, convert `X` into a `Matrix`
25-
$REQUIRE.
24+
If `X` is a Tables.jl compatible table source, convert `X` into a `Matrix`.
25+
2626
"""
2727
matrix(X; kw...) = matrix(get_interface_mode(), vtrait(X), X; kw...)
2828

@@ -43,7 +43,7 @@ matrix(::LightInterface, ::Val{:table}, X; kw...) = errlight("matrix")
4343
4444
The positional integer of the `CategoricalString` or `CategoricalValue` `x`, in
4545
the ordering defined by the pool of `x`. The type of `int(x)` is the reference
46-
type of `x` $REQUIRE.
46+
type of `x`.
4747
4848
Not to be confused with `x.ref`, which is unchanged by reordering of the pool
4949
of `x`, but has the same type.
@@ -82,10 +82,11 @@ int(::LightInterface, x) = errlight("int")
8282
"""
8383
classes(x)
8484
85-
All the categorical elements with the same pool as `x` (including `x`),
86-
returned as a list, with an ordering consistent with the pool $REQUIRE.
87-
Here `x` has `CategoricalValue` or `CategoricalString` type, and `classes(x)`
88-
is a vector of the same eltype. Note that `x in classes(x)` is always true.
85+
All the categorical elements with the same pool as `x` (including
86+
`x`), returned as a list, with an ordering consistent with the pool.
87+
Here `x` has `CategoricalValue` or `CategoricalString` type, and
88+
`classes(x)` is a vector of the same eltype. Note that `x in
89+
classes(x)` is always true.
8990
9091
Not to be confused with `levels(x.pool)`. See the example below.
9192
@@ -139,17 +140,31 @@ schema(::LightInterface, ::Val{:other}, X; kw...) = errlight("schema")
139140

140141
schema(::LightInterface, ::Val{:table}, X; kw...) = errlight("schema")
141142

143+
# ------------------------------------------------------------------------
144+
# istable
145+
146+
"""
147+
istable(X)
148+
149+
Return true if `X` is tabular.
150+
"""
151+
istable(X) = istable(get_interface_mode(), vtrait(X))
152+
153+
istable(::Mode, ::Val{:other}) = false
154+
155+
istable(::Mode, ::Val{:table}) = true
156+
142157
# ------------------------------------------------------------------------
143158
# decoder
144159

145160
"""
146161
d = decoder(x)
147162
148163
A callable object for decoding the integer representation of a
149-
`CategoricalString` or `CategoricalValue` sharing the same pool as `x`
150-
$REQUIRE. (Here `x` is of one of these two types.) Specifically, one has
151-
`d(int(y)) == y` for all `y in classes(x)`. One can also call `d` on integer
152-
arrays, in which case `d` is broadcast over all elements.
164+
`CategoricalString` or `CategoricalValue` sharing the same pool as
165+
`x`. (Here `x` is of one of these two types.) Specifically, one has
166+
`d(int(y)) == y` for all `y in classes(x)`. One can also call `d` on
167+
integer arrays, in which case `d` is broadcast over all elements.
153168
154169
julia> v = categorical([:c, :b, :c, :a])
155170
julia> int(v)
@@ -176,17 +191,17 @@ decoder(::LightInterface, x) = errlight("decoder")
176191
"""
177192
table(columntable; prototype=nothing)
178193
179-
Convert a named tuple of vectors or tuples `columntable`, into a table of the
180-
"preferred sink type" of `prototype` $REQUIRE. This is often the type of
194+
Convert a named tuple of vectors or tuples `columntable`, into a table
195+
of the "preferred sink type" of `prototype`. This is often the type of
181196
`prototype` itself, when `prototype` is a sink; see the Tables.jl
182-
documentation. If `prototype` is not specified, then a named tuple of vectors
183-
is returned.
197+
documentation. If `prototype` is not specified, then a named tuple of
198+
vectors is returned.
184199
185200
table(A::AbstractMatrix; names=nothing, prototype=nothing)
186201
187-
Wrap an abstract matrix `A` as a Tables.jl compatible table with the specified
188-
column `names` (a tuple of symbols). If `names` are not specified,
189-
`names=(:x1, :x2, ..., :xn)` is used, where `n=size(A, 2)` $REQUIRE.
202+
Wrap an abstract matrix `A` as a Tables.jl compatible table with the
203+
specified column `names` (a tuple of symbols). If `names` are not
204+
specified, `names=(:x1, :x2, ..., :xn)` is used, where `n=size(A, 2)`.
190205
191206
If a `prototype` is specified, then the matrix is materialized as a table of
192207
the preferred sink type of `prototype`, rather than wrapped. Note that if
@@ -202,7 +217,7 @@ table(::LightInterface, X; kw...) = errlight("table")
202217
"""
203218
nrows(X)
204219
205-
Return the number of rows for a table, abstract vector or matrix `X` $REQUIRE.
220+
Return the number of rows for a table, abstract vector or matrix `X`.
206221
"""
207222
nrows(X) = nrows(get_interface_mode(), vtrait(X), X)
208223

@@ -219,9 +234,11 @@ nrows(::LightInterface, ::Val{:table}, X) = errlight("table")
219234
"""
220235
selectrows(X, r)
221236
222-
Select single or multiple rows from a table, abstract vector or matrix `X`
223-
$REQUIRE. If `X` is tabular, the object returned is a table of the
224-
preferred sink type of `typeof(X)`, even if only a single row is selected.
237+
Select single or multiple rows from a table, abstract vector or matrix
238+
`X`. If `X` is tabular, the object returned is a table of the
239+
preferred sink type of `typeof(X)`, even if only a single row is
240+
selected.
241+
225242
"""
226243
selectrows(X, r) = selectrows(get_interface_mode(), vtrait(X), X, r)
227244

@@ -245,10 +262,11 @@ selectrows(::LightInterface, ::Val{:table}, X, r; kw...) =
245262
"""
246263
selectcols(X, c)
247264
248-
Select single or multiple columns from a matrix or table `X` $REQUIRE. If `c`
265+
Select single or multiple columns from a matrix or table `X`. If `c`
249266
is an abstract vector of integers or symbols, then the object returned
250267
is a table of the preferred sink type of `typeof(X)`. If `c` is a
251268
*single* integer or column, then an `AbstractVector` is returned.
269+
252270
"""
253271
selectcols(X, c) = selectcols(get_interface_mode(), vtrait(X), X, c)
254272

@@ -292,40 +310,124 @@ _squeeze(v) = first(v)
292310

293311
const UNIVARIATE_FINITE_DOCSTRING =
294312
"""
295-
UnivariateFinite(classes, p; pool=nothing, ordered=false)
313+
UnivariateFinite(support,
314+
probs;
315+
pool=nothing,
316+
augmented=false,
317+
ordered=false)
296318
297319
Construct a discrete univariate distribution whose finite support is
298-
the elements of the vector `classes`, and whose corresponding
299-
probabilities are elements of the vector `p`, which must sum to one $REQUIRE.
320+
the elements of the vector `support`, and whose corresponding
321+
probabilities are elements of the vector `probs`. More generally,
322+
construct an abstract *array* of `UnivariateFinite` distributions by
323+
choosing `probs` to be an array of one higher dimension than the array
324+
generated.
300325
301-
*Important.* Here `classes` must have type
326+
Unless `pool` is specified, `support` should have type
302327
`AbstractVector{<:CategoricalValue}` and all elements are assumed to
303-
share the same categorical pool. Raw classes *may* be used, but only provided
304-
`pool` is specified. The possible values are:
328+
share the same categorical pool, which may be larger than `support`.
329+
330+
*Important.* All levels of the common pool have associated
331+
probabilites, not just those in the specified `support`. However,
332+
these probabilities are always zero (see example below).
333+
334+
If `probs` is a matrix, it should have a column for each class in
335+
`support` (or one less, if `augment=true`). More generally, `probs`
336+
will be an array whose size is of the form `(n1, n2, ..., nk, c)`,
337+
where `c = length(suppport)` (or one less, if `augment=true`) and the
338+
constructor then returns an array of size `(n1, n2, ..., nk)`.
339+
340+
```
341+
using CategoricalArrays
342+
v = categorical([:x, :x, :y, :x, :z])
305343
306-
- some `v::CategoricalVector` such that `classes` is a subset of `levels(v)`
344+
julia> UnivariateFinite(classes(v), [0.2, 0.3, 0.5])
345+
UnivariateFinite{Multiclass{3}}(x=>0.2, y=>0.3, z=>0.5)
307346
308-
- some `a::CategoricalValue` such that `classes` is a subset of `levels(a)`
347+
julia> d = UnivariateFinite([v[1], v[end]], [0.1, 0.9])
348+
UnivariateFiniteMulticlass{3}(x=>0.1, z=>0.9)
349+
350+
julia> rand(d, 3)
351+
3-element Array{Any,1}:
352+
CategoricalArrays.CategoricalValue{Symbol,UInt32} :z
353+
CategoricalArrays.CategoricalValue{Symbol,UInt32} :z
354+
CategoricalArrays.CategoricalValue{Symbol,UInt32} :z
355+
356+
julia> levels(d)
357+
3-element Array{Symbol,1}:
358+
:x
359+
:y
360+
:z
361+
362+
julia> pdf(d, :y)
363+
0.0
364+
```
365+
366+
### Specifying a pool
367+
368+
Alternatively, `support` may be a list of raw (non-categorical)
369+
elements if `pool` is:
370+
371+
- some `CategoricalArray`, `CategoricalValue` or `CategoricalPool`,
372+
such that `support` is a subset of `levels(pool)`
309373
310374
- `missing`, in which case a new categorical pool is created which has
311-
`classes` as its only levels.
375+
`support` as its only levels.
376+
377+
In the last case, specify `ordered=true` if the pool is to be
378+
considered ordered.
312379
313-
In the last case specify `ordered=true` to order the new pool.
380+
```
381+
julia> UnivariateFinite([:x, :z], [0.1, 0.9], pool=missing, ordered=true)
382+
UnivariateFinite{OrderedFactor{2}}(x=>0.1, z=>0.9)
314383
315-
UnivariateFinite(prob_given_class; pool=nothing, ordered=false)
384+
julia> d = UnivariateFinite([:x, :z], [0.1, 0.9], pool=v) # v defined above
385+
UnivariateFinite(x=>0.1, z=>0.9) (Multiclass{3} samples)
386+
387+
julia> pdf(d, :y) # allowed as `:y in levels(v)`
388+
0.0
389+
390+
v = categorical([:x, :x, :y, :x, :z, :w])
391+
probs = rand(3, 100)
392+
probs = probs ./ sum(probs, dims=1)
393+
julia> UnivariateFinite([:x, :y, :z], probs, pool=v)
394+
100-element UnivariateFiniteVector{Multiclass{4},Symbol,UInt32,Float64}:
395+
UnivariateFinite{Multiclass{4}}(x=>0.194, y=>0.3, z=>0.505)
396+
UnivariateFinite{Multiclass{4}}(x=>0.727, y=>0.234, z=>0.0391)
397+
UnivariateFinite{Multiclass{4}}(x=>0.674, y=>0.00535, z=>0.321)
398+
399+
UnivariateFinite{Multiclass{4}}(x=>0.292, y=>0.339, z=>0.369)
400+
```
401+
402+
### Probability augmentation
403+
404+
Unless `augment=true`, sums of elements along the last axis (row-sums
405+
in the case of a matrix) must be equal to one, and otherwise such an
406+
array is created by inserting appropriate elements *ahead* of those
407+
provided. This means the provided probabilities are associated with
408+
the the classes `c2, c3, ..., cn`.
409+
410+
---
411+
412+
UnivariateFinite(prob_given_class; pool=nothing, ordered=false)
316413
317414
Construct a discrete univariate distribution whose finite support is
318415
the set of keys of the provided dictionary, `prob_given_class`, and
319-
whose values specify the corresponding probabilities $REQUIRE.
416+
whose values specify the corresponding probabilities.
320417
321418
The type requirements on the keys of the dictionary are the same as
322-
`classes` above.
419+
the elements of `support` given above. If the values (probabilities)
420+
are arrays instead of scalars, then an abstract array of
421+
`UnivariateFinite` elements is created, with the same size as the
422+
array.
323423
324424
"""
425+
UNIVARIATE_FINITE_DOCSTRING
325426
UnivariateFinite(d::AbstractDict; kwargs...) =
326427
UnivariateFinite(get_interface_mode(), d; kwargs...)
327-
UnivariateFinite(c::AbstractVector, p; kwargs...) =
328-
UnivariateFinite(get_interface_mode(), c, p; kwargs...)
329-
428+
UnivariateFinite(support::AbstractVector, probs; kwargs...) =
429+
UnivariateFinite(get_interface_mode(), support, probs; kwargs...)
430+
UnivariateFinite(probs; kwargs...) =
431+
UnivariateFinite(get_interface_mode(), probs; kwargs...)
330432
UnivariateFinite(::LightInterface, a...; kwargs...) =
331433
errlight("UnivariateFinite")

test/data_utils.jl

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,12 @@ end
3434
end
3535
@testset "int-full" begin
3636
setfull()
37-
M.int(::FI, x::CategoricalElement; kw...) =
38-
CategoricalArrays.order(x.pool)[x.level]
37+
M.int(::FI, x::CategoricalValue; kw...) =
38+
collect(1:length(levels(x.pool)))[x.level]
3939
x = categorical(['a','b','a'])
4040
@test int(x[1]) == 0x01
4141
@test int(x[2]) == 0x02
42-
@test int(x[2]) isa UInt32
42+
@test_broken int(x[2]) isa UInt32
4343
@test int(x[1], type=Int64) == 1
4444
@test int(x[1], type=Int64) isa Int64
4545
end
@@ -52,8 +52,8 @@ end
5252
@testset "classes-full" begin
5353
setfull()
5454
M.classes(::FI, p::CategoricalPool) =
55-
[p[i] for i in invperm(CategoricalArrays.order(p))]
56-
M.classes(::FI, x::CategoricalElement) = classes(x.pool)
55+
[p[i] for i in invperm(1:length(levels(p)))]
56+
M.classes(::FI, x::CategoricalValue) = classes(x.pool)
5757
x = categorical(['a','b','a'])
5858
@test classes(x[1]) == ['a', 'b']
5959
end
@@ -80,6 +80,16 @@ end
8080
@test sch.scitypes[2] <: Multiclass
8181
end
8282
# ------------------------------------------------------------------------
83+
@testset "istable" begin
84+
setlight()
85+
X = rand(5)
86+
@test !M.istable(X)
87+
X = randn(5,5)
88+
@test !M.istable(X)
89+
X = DataFrame(A=rand(10))
90+
@test M.istable(X)
91+
end
92+
# ------------------------------------------------------------------------
8393
@testset "decoder-light" begin
8494
setlight()
8595
x = 5

test/runtests.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import MLJBase
66

77
const M = MLJModelInterface
88
const FI = M.FullInterface
9-
const CategoricalElement = Union{CategoricalValue,CategoricalString}
109
ScientificTypes.TRAIT_FUNCTION_GIVEN_NAME[:table] = Tables.istable
1110

1211
setlight() = M.set_interface_mode(M.LightInterface())

0 commit comments

Comments
 (0)