Skip to content

Commit 250b393

Browse files
authored
Refactor Learn transform in terms of new LabeledTable (#16)
* Add LabeledTable * Implement Tables.jl interface for LabeledTable * Refactor Learn transform * Update README.md * Add IO methods for LabeledTable * Remove MLJ extension * Refactor tests * Adjust label style in IO methods * Implement accessor methods for LabeledTable * Refactor Learn in terms of LabeledTable api * Bump version * Change colors in IO methods
1 parent 5a7e7ab commit 250b393

File tree

9 files changed

+261
-204
lines changed

9 files changed

+261
-204
lines changed

Project.toml

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "StatsLearnModels"
22
uuid = "c146b59d-1589-421c-8e09-a22e554fd05c"
3+
version = "1.2.0"
34
authors = ["Elias Carvalho <eliascarvdev@gmail.com> and contributors"]
4-
version = "1.1.1"
55

66
[deps]
77
ColumnSelectors = "9cc86067-7e36-4c61-b350-1ac9833d277f"
@@ -11,26 +11,23 @@ Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
1111
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1212
GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
1313
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
14+
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
1415
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
16+
StyledStrings = "f489334b-da3d-4c2e-b8f0-e476e12c162b"
1517
TableTransforms = "0d432bfd-3ee1-4ac1-886a-39f05cc69a3e"
1618
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
1719

18-
[weakdeps]
19-
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
20-
21-
[extensions]
22-
StatsLearnModelsMLJModelInterfaceExt = "MLJModelInterface"
23-
2420
[compat]
2521
ColumnSelectors = "1.0"
2622
DataScienceTraits = "1.0"
2723
DecisionTree = "0.12"
2824
Distances = "0.10"
2925
Distributions = "0.25"
3026
GLM = "1.9"
31-
MLJModelInterface = "1.9"
3227
NearestNeighbors = "0.4"
28+
PrettyTables = "3.0.2"
3329
StatsBase = "0.33, 0.34"
30+
StyledStrings = "1.0"
3431
TableTransforms = "1.15"
3532
Tables = "1.11"
3633
julia = "1.9"

README.md

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,27 +26,22 @@ train = (feature1=rand(100), feature2=rand(100), target=rand(1:2, 100))
2626
test = (feature1=rand(20), feature2=rand(20))
2727
```
2828

29-
One can train a learning `model` (e.g. `RandomForestClassifier`) with
30-
the `train` table:
29+
It is possible to train a learning `model` (e.g. `RandomForestClassifier`) with
30+
the `train` table to approximate a `:target` label and perform predictions
31+
with the `test` table:
3132

3233
```julia
3334
model = RandomForestClassifier()
34-
35-
learn = Learn(train, model, ["feature1","feature2"] => "target")
36-
```
37-
38-
and apply the trained `model` to the `test` table:
39-
40-
```julia
41-
pred = learn(test)
35+
learn = Learn(label(train, :target); model)
36+
preds = learn(test)
4237
```
4338

44-
The package exports native Julia models from various packages
45-
in the ecosystem. It is also possible to use models from the
46-
[MLJ.jl](https://github.com/JuliaAI/MLJ.jl) stack.
39+
The function `label` is used to tag columns of the table with target labels,
40+
which can be categorical or continuous. Remaining columns are assumed to be
41+
predictors.
4742

48-
The combination of TableTransforms.jl with StatsLearnModels.jl
49-
can be thought of as a powerful alternative to MLJ.jl.
43+
Please check the [models](https://github.com/JuliaML/StatsLearnModels.jl/tree/main/src/models)
44+
directory for documentation on available models and their parameters.
5045

5146
[build-img]: https://img.shields.io/github/actions/workflow/status/JuliaML/StatsLearnModels.jl/CI.yml?branch=main&style=flat-square
5247
[build-url]: https://github.com/JuliaML/StatsLearnModels.jl/actions

ext/StatsLearnModelsMLJModelInterfaceExt.jl

Lines changed: 0 additions & 36 deletions
This file was deleted.

src/StatsLearnModels.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@
44

55
module StatsLearnModels
66

7-
using Tables
87
using Distances
8+
using PrettyTables
9+
using StyledStrings
910
using DataScienceTraits
1011
using StatsBase: mode, mean
1112
using ColumnSelectors: ColumnSelector, selector
1213
using TableTransforms: StatelessFeatureTransform
1314

15+
import Tables
1416
import TableTransforms: applyfeat, isrevertible
1517

1618
using DecisionTree: AdaBoostStumpClassifier, DecisionTreeClassifier, RandomForestClassifier
@@ -22,13 +24,20 @@ import GLM
2224
import DecisionTree as DT
2325
import NearestNeighbors as NN
2426

27+
include("labeledtable.jl")
2528
include("interface.jl")
2629
include("models/nn.jl")
2730
include("models/glm.jl")
2831
include("models/tree.jl")
2932
include("learn.jl")
3033

3134
export
35+
# labeled table
36+
LabeledTable,
37+
predictors,
38+
targets,
39+
label,
40+
3241
# NearestNeighbors.jl
3342
KNNClassifier,
3443
KNNRegressor,

src/interface.jl

Lines changed: 9 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,47 +2,21 @@
22
# Licensed under the MIT License. See LICENSE in the project root.
33
# ------------------------------------------------------------------
44

5-
"""
6-
StatsLearnModel(model, features, targets)
7-
8-
Wrap a statistical learning `model` with selectors
9-
of `features` and `targets`.
10-
11-
## Examples
12-
13-
```julia
14-
StatsLearnModel(DecisionTreeClassifier(), ["x1","x2"], "y")
15-
StatsLearnModel(DecisionTreeClassifier(), 1:3, "target")
16-
```
17-
"""
18-
struct StatsLearnModel{M,F<:ColumnSelector,T<:ColumnSelector}
19-
model::M
20-
feats::F
21-
targs::T
22-
end
23-
24-
StatsLearnModel(model, feats, targs) = StatsLearnModel(model, selector(feats), selector(targs))
25-
265
"""
276
fit(model, input, output)
287
29-
Fit statistical learning `model` using features in `input` table
30-
and targets in `output` table. Returns a fitted model with all
31-
the necessary information for prediction with the `predict` function.
8+
Fit statistical learning `model` using predictors
9+
in `input` table and targets in `output` table.
10+
Returns a fitted model with all the necessary
11+
information for prediction with [`predict`](@ref).
3212
"""
3313
function fit end
3414

35-
function Base.show(io::IO, model::StatsLearnModel{M}) where {M}
36-
println(io, "StatsLearnModel{$(nameof(M))}")
37-
println(io, "├─ features: $(model.feats)")
38-
print(io, "└─ targets: $(model.targs)")
39-
end
40-
4115
"""
4216
FittedStatsLearnModel(model, cache)
4317
44-
Wrap the statistical learning `model` with the `cache`
45-
produced during the [`fit`](@ref) stage.
18+
Wrap the statistical learning `model` with the
19+
`cache` produced during the [`fit`](@ref) stage.
4620
"""
4721
struct FittedStatsLearnModel{M,C}
4822
model::M
@@ -53,7 +27,9 @@ end
5327
predict(model::FittedStatsLearnModel, table)
5428
5529
Predict targets using the fitted statistical
56-
learning `model` and a new `table` of features.
30+
learning `model` and a new `table` containing
31+
the same predictors used during the [`fit`](@ref)
32+
stage.
5733
"""
5834
function predict end
5935

src/labeledtable.jl

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# ------------------------------------------------------------------
2+
# Licensed under the MIT License. See LICENSE in the project root.
3+
# ------------------------------------------------------------------
4+
5+
"""
6+
LabeledTable(table, names)
7+
8+
Stores a Tables.jl `table` along with column `names` that
9+
identify which columns are labels for supervised learning.
10+
"""
11+
struct LabeledTable{T}
12+
table::T
13+
labels::Vector{Symbol}
14+
end
15+
16+
function LabeledTable(table, names)
17+
Tables.istable(table) || throw(ArgumentError("please provide a valid Tables.jl table"))
18+
cols = Tables.columns(table)
19+
vars = Tables.columnnames(cols)
20+
labs = selector(names)(vars)
21+
labs ⊆ vars || throw(ArgumentError("all labels must be column names in the table"))
22+
vars ⊆ labs && throw(ArgumentError("there must be at least one feature column in the table"))
23+
LabeledTable{typeof(table)}(table, labs)
24+
end
25+
26+
# -----------------
27+
# TABLES INTERFACE
28+
# -----------------
29+
30+
Tables.istable(::Type{<:LabeledTable}) = true
31+
32+
Tables.rowaccess(::Type{<:LabeledTable{T}}) where {T} = Tables.rowaccess(T)
33+
34+
Tables.columnaccess(::Type{<:LabeledTable{T}}) where {T} = Tables.columnaccess(T)
35+
36+
Tables.rows(t::LabeledTable) = Tables.rows(t.table)
37+
38+
Tables.columns(t::LabeledTable) = Tables.columns(t.table)
39+
40+
Tables.columnnames(t::LabeledTable) = Tables.columnnames(t.table)
41+
42+
# -----------------
43+
# ACCESSOR METHODS
44+
# -----------------
45+
46+
Base.parent(t::LabeledTable) = t.table
47+
48+
function predictors(t::LabeledTable)
49+
cols = Tables.columns(t.table)
50+
vars = Tables.columnnames(cols)
51+
setdiff(vars, t.labels)
52+
end
53+
54+
targets(t::LabeledTable) = t.labels
55+
56+
# -----------
57+
# IO METHODS
58+
# -----------
59+
60+
function Base.summary(io::IO, t::LabeledTable)
61+
name = nameof(typeof(t))
62+
nlab = length(t.labels)
63+
print(io, "$name with $nlab label(s)")
64+
end
65+
66+
Base.show(io::IO, t::LabeledTable) = summary(io, t)
67+
68+
function Base.show(io::IO, ::MIME"text/plain", t::LabeledTable)
69+
pretty_table(io, t; backend=:text, _common_kwargs(t)...)
70+
end
71+
72+
function Base.show(io::IO, ::MIME"text/html", t::LabeledTable)
73+
pretty_table(
74+
io,
75+
t;
76+
backend=:html,
77+
_common_kwargs(t)...,
78+
renderer=:show,
79+
style=HtmlTableStyle(title=["font-size" => "14px"])
80+
)
81+
end
82+
83+
function _common_kwargs(t)
84+
cols = Tables.columns(t)
85+
vars = Tables.columnnames(cols)
86+
87+
labels = map(vars) do var
88+
if var ∈ t.labels
89+
styled"{(weight=bold),magenta:$var}"
90+
else
91+
styled"{(weight=bold),yellow:$var}"
92+
end
93+
end
94+
95+
(
96+
title=summary(t),
97+
column_labels=collect(labels),
98+
maximum_number_of_rows=10,
99+
new_line_at_end=false,
100+
alignment=:c
101+
)
102+
end
103+
104+
"""
105+
label(table, names)
106+
107+
Creates a `LabeledTable` from `table` using `names` as label columns.
108+
"""
109+
label(table, names) = LabeledTable(table, names)

0 commit comments

Comments
 (0)