Skip to content

Commit c8588a9

Browse files
authored
data utils (#3)
1 parent 7fb9d6b commit c8588a9

16 files changed

+1271
-60
lines changed

Manifest.toml

Lines changed: 0 additions & 6 deletions
This file was deleted.

Project.toml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,17 @@ version = "0.1.0"
77
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
88

99
[compat]
10-
ScientificTypes = "^0.6"
10+
ScientificTypes = "^0.7"
1111
julia = "1"
1212

1313
[extras]
14+
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
15+
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
16+
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
17+
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
1418
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
1519
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1620

1721
[targets]
18-
test = ["Test", "Tables"]
22+
test = ["Test", "Tables", "Distances", "CategoricalArrays", "InteractiveUtils",
23+
"DataFrames"]

src/MLJModelInterface.jl

Lines changed: 55 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,43 @@ module MLJModelInterface
22

33
# ------------------------------------------------------------------------
44
# Dependency (note that ScientificTypes itself does not have dependencies)
5-
import ScientificTypes: trait
5+
using ScientificTypes
66

77
# ------------------------------------------------------------------------
8-
# Single export: matrix, everything else is qualified in MLJBase
9-
export matrix
8+
# exports
9+
10+
# types
11+
export LightInterface, FullInterface
12+
export MLJType, Model, Supervised, Unsupervised,
13+
Probabilistic, Deterministic, Interval, Static,
14+
UnivariateFinite
15+
16+
# rexport types from ScientificTypes
17+
export Scientific, Found, Unknown, Known, Finite, Infinite,
18+
OrderedFactor, Multiclass, Count, Continuous, Textual,
19+
Binary, ColorImage, GrayImage, Table
20+
21+
# constructor + metadata
22+
export @mlj_model, metadata_pkg, metadata_model, metadata_measure
23+
# api
24+
export fit, update, update_data, transform, inverse_transform,
25+
fitted_params, predict, predict_mode, predict_mean, predict_median,
26+
evaluate, clean!
27+
# traits
28+
export input_scitype, output_scitype, target_scitype,
29+
is_pure_julia, package_name, package_license,
30+
load_path, package_uuid, package_url,
31+
is_wrapper, supports_weights, supports_online,
32+
docstring, name, is_supervised,
33+
prediction_type, implemented_methods, hyperparameters,
34+
hyperparameter_types, hyperparameter_ranges
35+
36+
# data operations
37+
export matrix, int, classes, decoder, table,
38+
nrows, selectrows, selectcols, select
1039

1140
# ------------------------------------------------------------------------
41+
# Mode trick
1242

1343
abstract type Mode end
1444
struct LightInterface <: Mode end
@@ -24,24 +54,32 @@ struct InterfaceError <: Exception
2454
m::String
2555
end
2656

27-
vtrait(X) = X |> trait |> Val
57+
# ------------------------------------------------------------------------
58+
# Model types
59+
60+
abstract type MLJType end
61+
62+
abstract type Model <: MLJType end
63+
64+
abstract type Supervised <: Model end
65+
abstract type Unsupervised <: Model end
2866

29-
"""
30-
matrix(X; transpose=false)
67+
abstract type Probabilistic <: Supervised end
68+
abstract type Deterministic <: Supervised end
69+
abstract type Interval <: Supervised end
3170

32-
If `X <: AbstractMatrix`, return `X` or `permutedims(X)` if `transpose=true`.
33-
If `X` is a Tables.jl compatible table source, convert `X` into a `Matrix`.
34-
"""
35-
matrix(X; kw...) = matrix(vtrait(X), X, get_interface_mode(); kw...)
71+
abstract type Static <: Unsupervised end
72+
73+
# ------------------------------------------------------------------------
74+
# includes
3675

37-
matrix(::Val{:other}, X::AbstractMatrix, ::Mode; transpose=false) =
38-
transpose ? permutedims(X) : X
76+
include("utils.jl")
3977

40-
matrix(::Val{:other}, X, ::Mode; kw...) =
41-
throw(ArgumentError("Function `matrix` only supports AbstractMatrix or " *
42-
"containers implementing the Tables interface."))
78+
include("data_utils.jl")
79+
include("metadata_utils.jl")
4380

44-
matrix(::Val{:table}, X, ::LightInterface; kw...) =
45-
throw(InterfaceError("Only `MLJModelInterface` loaded. Import `MLJBase`."))
81+
include("model_traits.jl")
82+
include("model_def.jl")
83+
include("model_api.jl")
4684

4785
end # module

src/data_utils.jl

Lines changed: 276 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,276 @@
1+
2+
vtrait(X) = X |> trait |> Val
3+
4+
const REQUIRE = "(requires MLJBase to be loaded)"
5+
6+
errlight(s) = throw(InterfaceError("Only `MLJModelInterface` is loaded. " *
7+
"Import `MLJBase` in order to use `$s`."))
8+
9+
# ------------------------------------------------------------------------
10+
# matrix
11+
12+
"""
13+
matrix(X; transpose=false)
14+
15+
If `X <: AbstractMatrix`, return `X` or `permutedims(X)` if `transpose=true`.
16+
If `X` is a Tables.jl compatible table source, convert `X` into a `Matrix`
17+
$REQUIRE.
18+
"""
19+
matrix(X; kw...) = matrix(get_interface_mode(), vtrait(X), X; kw...)
20+
21+
matrix(::Mode, ::Val{:other}, X::AbstractMatrix; transpose=false) =
22+
transpose ? permutedims(X) : X
23+
24+
matrix(::Mode, ::Val{:other}, X; kw...) =
25+
throw(ArgumentError("Function `matrix` only supports AbstractMatrix or " *
26+
"containers implementing the Tables interface."))
27+
28+
matrix(::LightInterface, ::Val{:table}, X; kw...) = errlight("matrix")
29+
30+
# ------------------------------------------------------------------------
31+
# int
32+
33+
"""
34+
int(x)
35+
36+
The positional integer of the `CategoricalString` or `CategoricalValue` `x`, in
37+
the ordering defined by the pool of `x`. The type of `int(x)` is the reference
38+
type of `x` $REQUIRE.
39+
40+
Not to be confused with `x.ref`, which is unchanged by reordering of the pool
41+
of `x`, but has the same type.
42+
43+
int(X::CategoricalArray)
44+
int(W::Array{<:CategoricalString})
45+
int(W::Array{<:CategoricalValue})
46+
47+
Broadcasted versions of `int`.
48+
49+
julia> v = categorical([:c, :b, :c, :a])
50+
julia> levels(v)
51+
3-element Array{Symbol,1}:
52+
:a
53+
:b
54+
:c
55+
julia> int(v)
56+
4-element Array{UInt32,1}:
57+
0x00000003
58+
0x00000002
59+
0x00000003
60+
0x00000001
61+
62+
See also: [`decoder`](@ref).
63+
"""
64+
int(x; kw...) = int(get_interface_mode(), x; kw...)
65+
66+
int(::LightInterface, x; kw...) = errlight("int")
67+
68+
# ------------------------------------------------------------------------
69+
# classes
70+
71+
"""
72+
classes(x)
73+
74+
All the categorical elements with in the same pool as `x` (including `x`),
75+
returned as a list, with an ordering consistent with the pool $REQUIRE.
76+
Here `x` has `CategoricalValue` or `CategoricalString` type, and `classes(x)`
77+
is a vector of the same eltype. Note that `x in classes(x)` is always true.
78+
79+
Not to be confused with `levels(x.pool)`. See the example below.
80+
81+
julia> v = categorical([:c, :b, :c, :a])
82+
4-element CategoricalArrays.CategoricalArray{Symbol,1,UInt32}:
83+
:c
84+
:b
85+
:c
86+
:a
87+
88+
julia> levels(v)
89+
3-element Array{Symbol,1}:
90+
:a
91+
:b
92+
:c
93+
94+
julia> x = v[4]
95+
CategoricalArrays.CategoricalValue{Symbol,UInt32} :a
96+
97+
julia> classes(x)
98+
3-element CategoricalArrays.CategoricalArray{Symbol,1,UInt32}:
99+
:a
100+
:b
101+
:c
102+
103+
julia> levels(x.pool)
104+
3-element Array{Symbol,1}:
105+
:a
106+
:b
107+
:c
108+
109+
"""
110+
classes(x) = classes(get_interface_mode(), x)
111+
112+
classes(::LightInterface, x) = errlight("classes")
113+
114+
# ------------------------------------------------------------------------
115+
# decoder
116+
117+
"""
118+
d = decoder(x)
119+
120+
A callable object for decoding the integer representation of a
121+
`CategoricalString` or `CategoricalValue` sharing the same pool as `x`
122+
$REQUIRE. (Here `x` is of one of these two types.) Specifically, one has
123+
`d(int(y)) == y` for all `y in classes(x)`. One can also call `d` on integer
124+
arrays, in which case `d` is broadcast over all elements.
125+
126+
julia> v = categorical([:c, :b, :c, :a])
127+
julia> int(v)
128+
4-element Array{UInt32,1}:
129+
0x00000003
130+
0x00000002
131+
0x00000003
132+
0x00000001
133+
julia> d = decoder(v[3])
134+
julia> d(int(v)) == v
135+
true
136+
137+
*Warning:* It is *not* true that `int(d(u)) == u` always holds.
138+
139+
See also: [`int`](@ref), [`classes`](@ref).
140+
"""
141+
decoder(x) = decoder(get_interface_mode(), x)
142+
143+
decoder(::LightInterface, x) = errlight("decoder")
144+
145+
# ------------------------------------------------------------------------
146+
# table
147+
148+
"""
149+
table(columntable; prototype=nothing)
150+
151+
Convert a named tuple of vectors or tuples `columntable`, into a table of the
152+
"preferred sink type" of `prototype` $REQUIRE. This is often the type of
153+
`prototype` itself, when `prototype` is a sink; see the Tables.jl
154+
documentation. If `prototype` is not specified, then a named tuple of vectors
155+
is returned.
156+
157+
table(A::AbstractMatrix; names=nothing, prototype=nothing)
158+
159+
Wrap an abstract matrix `A` as a Tables.jl compatible table with the specified
160+
column `names` (a tuple of symbols). If `names` are not specified,
161+
`names=(:x1, :x2, ..., :xn)` is used, where `n=size(A, 2)` $REQUIRE.
162+
163+
If a `prototype` is specified, then the matrix is materialized as a table of
164+
the preferred sink type of `prototype`, rather than wrapped. Note that if
165+
`prototype` is *not* specified, then `matrix(table(A))` is essentially a no-op.
166+
"""
167+
table(X; kw...) = table(get_interface_mode(), X; kw...)
168+
169+
table(::LightInterface, X; kw...) = errlight("table")
170+
171+
# ------------------------------------------------------------------------
172+
# nrows, select, selectrows, selectcols
173+
174+
"""
175+
nrows(X)
176+
177+
Return the number of rows for a table, abstract vector or matrix `X` $REQUIRE.
178+
"""
179+
nrows(X) = nrows(get_interface_mode(), vtrait(X), X)
180+
181+
nrows(::Mode, ::Val{:other}, X::AbstractVecOrMat) = size(X, 1)
182+
183+
nrows(::Mode, ::Val{:other}, X) =
184+
throw(ArgumentError("Function `nrows` only supports AbstractVector or " *
185+
"AbstractMatrix or containers implementing the " *
186+
"Tables interface."))
187+
188+
nrows(::LightInterface, ::Val{:table}, X) = errlight("table")
189+
190+
"""
191+
selectrows(X, r)
192+
193+
Select single or multiple rows from a table, abstract vector or matrix `X`
194+
$REQUIRE. If `X` is tabular, the object returned is a table of the
195+
preferred sink type of `typeof(X)`, even if only a single row is selected.
196+
"""
197+
selectrows(X, r) = selectrows(get_interface_mode(), vtrait(X), X, r)
198+
199+
selectrows(::Mode, ::Val{:other}, ::Nothing, r) = nothing
200+
201+
selectrows(::Mode, ::Val{:other}, X::AbstractVector, r) = X[r]
202+
selectrows(::Mode, ::Val{:other}, X::AbstractVector, r::Integer) = X[r:r]
203+
selectrows(::Mode, ::Val{:other}, X::AbstractVector, ::Colon) = X
204+
205+
selectrows(::Mode, ::Val{:other}, X::AbstractMatrix, r) = X[r, :]
206+
selectrows(::Mode, ::Val{:other}, X::AbstractMatrix, r::Integer) = X[r:r, :]
207+
selectrows(::Mode, ::Val{:other}, X::AbstractMatrix, ::Colon) = X
208+
209+
selectrows(::Mode, ::Val{:other}, X, r) =
210+
throw(ArgumentError("Function `selectrows` only supports AbstractVector " *
211+
"or AbstractMatrix or containers implementing the " * "Tables interface."))
212+
213+
selectrows(::LightInterface, ::Val{:table}, X, r; kw...) =
214+
errlight("selectrows")
215+
216+
"""
217+
selectcols(X, c)
218+
219+
Select single or multiple columns from a matrix or table `X` $REQUIRE. If `c`
220+
is an abstract vector of integers or symbols, then the object returned
221+
is a table of the preferred sink type of `typeof(X)`. If `c` is a
222+
*single* integer or column, then an `AbstractVector` is returned.
223+
"""
224+
selectcols(X, c) = selectcols(get_interface_mode(), vtrait(X), X, c)
225+
226+
selectcols(::Mode, ::Val{:other}, ::Nothing, c) = nothing
227+
228+
selectcols(::Mode, ::Val{:other}, X::AbstractMatrix, r) = X[:, r]
229+
selectcols(::Mode, ::Val{:other}, X::AbstractMatrix, ::Colon) = X
230+
231+
selectcols(::Mode, ::Val{:other}, X, r) =
232+
throw(ArgumentError("Function `selectcols` only supports AbstractMatrix " *
233+
"or containers implementing the Tables interface."))
234+
235+
selectcols(::LightInterface, ::Val{:table}, X, c; kw...) =
236+
errlight("selectcols")
237+
238+
"""
239+
select(X, r, c)
240+
241+
Select element(s) of a table or matrix at row(s) `r` and column(s) `c`. In the
242+
case of sparse data where the key `(r, c)`, zero or `missing` is returned,
243+
depending on the value type. See also: [`selectrows`](@ref),
244+
[`selectcols`](@ref).
245+
"""
246+
select(X, r, c) = select(get_interface_mode(), vtrait(X), X, r, c)
247+
248+
select(::Mode, ::Val, X, r, c) = selectcols(selectrows(X, r), c)
249+
250+
# ------------------------------------------------------------------------
251+
# UnivariateFinite
252+
253+
"""
254+
UnivariateFinite(classes, p)
255+
256+
A discrete univariate distribution whose finite support is the elements of the
257+
vector `classes`, and whose corresponding probabilities are elements of the
258+
vector `p`, which must sum to one $REQUIRE.. Here `classes` must have type
259+
`AbstractVector{<:CategoricalElement}` where
260+
261+
CategoricalElement = Union{CategoricalValue,CategoricalString}
262+
263+
and all classes are assumed to share the same categorical pool.
264+
265+
UnivariateFinite(prob_given_class)
266+
267+
A discrete univariate distribution whose finite support is the set of keys of
268+
the provided dictionary, `prob_given_class` $REQUIRE.. The dictionary keys must
269+
be of type `CategoricalElement` (see above) and the dictionary values specify
270+
the corresponding probabilities.
271+
"""
272+
UnivariateFinite(d::AbstractDict) = UnivariateFinite(get_interface_mode(), d)
273+
UnivariateFinite(c::AbstractVector, p::AbstractVector) =
274+
UnivariateFinite(get_interface_mode(), c, p)
275+
276+
UnivariateFinite(::LightInterface, a...) = errlight("UnivariateFinite")

0 commit comments

Comments
 (0)