Skip to content

Commit 7e5f651

Browse files
committed
add learners/classification.jl and tests
1 parent 1d165a2 commit 7e5f651

File tree

8 files changed

+194
-4
lines changed

8 files changed

+194
-4
lines changed

Project.toml

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ authors = ["Anthony D. Blaom <[email protected]>"]
44
version = "0.2.1"
55

66
[deps]
7+
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
8+
CategoricalDistributions = "af321ab8-2d2e-40a6-b165-3d674595d28e"
79
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
810
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
911
IsURL = "ceb4388c-583f-448d-bb30-00b11e8c5682"
@@ -23,6 +25,8 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2325
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
2426

2527
[compat]
28+
CategoricalArrays = "0.10.8"
29+
CategoricalDistributions = "0.1.15"
2630
Distributions = "0.25"
2731
InteractiveUtils = "<0.0.1, 1"
2832
IsURL = "0.2.0"
@@ -47,7 +51,16 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
4751
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
4852
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
4953
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
54+
StatsModels = "3eaba693-59b7-5ba5-a881-562e759f1c8d"
5055
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
5156

5257
[targets]
53-
test = ["DataFrames", "Distributions", "Random", "LinearAlgebra", "Statistics", "Tables"]
58+
test = [
59+
"DataFrames",
60+
"Distributions",
61+
"Random",
62+
"LinearAlgebra",
63+
"Statistics",
64+
"StatsModels",
65+
"Tables",
66+
]

src/LearnTestAPI.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ include("logging.jl")
5353
include("testapi.jl")
5454
include("learners/static_algorithms.jl")
5555
include("learners/regression.jl")
56+
include("learners/classification.jl")
5657
include("learners/ensembling.jl")
5758
# next learner excluded because of heavy dependencies:
5859
# include("learners/gradient_descent.jl")

src/learners/classification.jl

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# This file defines `ConstantClassifier()`
2+
3+
using LearnAPI
4+
import LearnDataFrontEnds as FrontEnds
5+
import MLCore
6+
import CategoricalArrays
7+
import CategoricalDistributions
8+
import CategoricalDistributions.OrderedCollections.OrderedDict
9+
import CategoricalDistributions.Distributions.StatsBase.proportionmap
10+
11+
# The implementation of a constant classifier below is not the simplest, but it
12+
# demonstrates some patterns that apply more generally in classification.
13+
14+
"""
15+
ConstantClassifier()
16+
17+
Instantiate a constant (dummy) classifier. Can predict `Point` or `Distribution` targets.
18+
19+
"""
20+
struct ConstantClassifier end
21+
22+
struct ConstantClassifierFitted
23+
learner::ConstantClassifier
24+
probabilities
25+
names::Vector{Symbol}
26+
classes_seen
27+
codes_seen
28+
decoder
29+
end
30+
31+
# accessor: recover the learner from its fitted output
LearnAPI.learner(model::ConstantClassifierFitted) = model.learner

# add a data front end; `obs` will return objects with type `FrontEnds.Obs`:
const front_end = FrontEnds.Sage(code_type=:small)
LearnAPI.obs(learner::ConstantClassifier, data) =
    FrontEnds.fitobs(learner, data, front_end)
LearnAPI.obs(model::ConstantClassifierFitted, data) =
    obs(model, data, front_end)

# data deconstructors (the 3-argument methods are supplied by the front end):
LearnAPI.features(learner::ConstantClassifier, data) =
    LearnAPI.features(learner, data, front_end)
LearnAPI.target(learner::ConstantClassifier, data) =
    LearnAPI.target(learner, data, front_end)
45+
46+
# Training: record the empirical class distribution of the target; nothing else is
# learned. Returns a `ConstantClassifierFitted` carrying everything `predict` needs.
function LearnAPI.fit(learner::ConstantClassifier, observations::FrontEnds.Obs; verbosity=1)
    codes = observations.target             # the target, as integer "codes"
    distinct_codes = sort(unique(codes))    # codes seen, in increasing order
    proportions = proportionmap(codes)      # code => empirical proportion

    # one probability per distinct code, in code order (keys of `proportions` are
    # exactly `unique(codes)`, so every lookup succeeds):
    probs = [proportions[code] for code in distinct_codes]

    return ConstantClassifierFitted(
        learner,
        probs,
        observations.names,
        observations.classes_seen,
        distinct_codes,
        observations.decoder,
    )
end
# fallback for raw (un-preprocessed) data: route through the data front end first
LearnAPI.fit(learner::ConstantClassifier, data; kwargs...) =
    fit(learner, obs(learner, data); kwargs...)
68+
69+
# `Point` predictions: every observation is assigned the modal (most probable) class.
function LearnAPI.predict(
    model::ConstantClassifierFitted,
    ::Point,
    observations::FrontEnds.Obs,
)
    nobs = MLCore.numobs(observations)
    # integer code of the class with the largest training proportion:
    modal_code = model.codes_seen[last(findmax(model.probabilities))]
    # decode the repeated code back to categorical values, one per observation:
    return model.decoder.(fill(modal_code, nobs))
end
LearnAPI.predict(model::ConstantClassifierFitted, ::Point, data) =
    predict(model, Point(), obs(model, data))
81+
82+
# `Distribution` predictions: the same `UnivariateFinite` distribution (the empirical
# training distribution) is returned for every observation.
function LearnAPI.predict(
    model::ConstantClassifierFitted,
    ::Distribution,
    observations::FrontEnds.Obs,
)
    nobs = MLCore.numobs(observations)
    p = model.probabilities
    # n×k matrix whose every row is `p` (one row per observation):
    prob_rows = repeat(reshape(p, 1, length(p)), nobs, 1)
    return CategoricalDistributions.UnivariateFinite(model.classes_seen, prob_rows)
end
LearnAPI.predict(model::ConstantClassifierFitted, ::Distribution, data) =
    predict(model, Distribution(), obs(model, data))
95+
96+
# accessor function:
LearnAPI.feature_names(model::ConstantClassifierFitted) = model.names

# declare the LearnAPI.jl contract implemented by `ConstantClassifier`:
@trait(
    ConstantClassifier,
    constructor = ConstantClassifier,
    kinds_of_proxy = (Point(),Distribution()),
    tags = ("classification",),
    functions = (
        :(LearnAPI.fit),
        :(LearnAPI.learner),
        :(LearnAPI.clone),
        :(LearnAPI.strip),
        :(LearnAPI.obs),
        :(LearnAPI.features),
        :(LearnAPI.target),
        :(LearnAPI.predict),
        :(LearnAPI.feature_names),
    )
)

# value returned by `include`; signals the file loaded successfully
true

src/learners/ensembling.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ LearnAPI.components(model::EnsembleFitted) = [:atom => model.models,]
211211
# - `out_of_sample_losses`
212212

213213
# For simplicity, this implementation is restricted to univariate features. The simplistic
214-
# algorithm is explained in the docstring. of the data presented.
214+
# algorithm is explained in the docstring.
215215

216216

217217
# ## HELPERS
@@ -276,6 +276,7 @@ function update!(
276276
stump = Stump(ξ, left, right)
277277
push!(forest, stump)
278278
new_predictions = _predict(stump, x)
279+
279280
# efficient in-place update of `predictions`:
280281
predictions .= (k*predictions .+ new_predictions)/(k + 1)
281282
push!(training_losses, (predictions[training_indices] .- ytrain).^2 |> sum)

test/learners/classification.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
using Test
2+
using LearnTestAPI
3+
using LearnAPI
4+
import MLCore
5+
using StableRNGs
6+
import DataFrames
7+
using Tables
8+
import CategoricalArrays
9+
import StatsModels: @formula
10+
import CategoricalDistributions.pdf
11+
12+
# # SYNTHESIZE LOTS OF DATASETS

n = 2
rng = StableRNG(345)
# categorical target with 3n observations; the pool also
# has a "hidden" level, `C`:
t = CategoricalArrays.categorical(repeat("ABA", 3n)*"CC" |> collect)[1:3n]
# two numeric features:
c, a = randn(rng, 3n), rand(rng, 3n)
y = t
Y = (; t)

# feature matrix:
x = hcat(c, a) |> permutedims

# feature tables (same data in several Tables.jl-compatible forms):
X = (; c, a)
X1, X2, X3, X4, X5 = X,
    Tables.rowtable(X),
    Tables.dictrowtable(X),
    Tables.dictcolumntable(X),
    DataFrames.DataFrame(X);

# full tables (features and target together):
T = (; c, t, a)
T1, T2, T3, T4, T5 = T,
    Tables.rowtable(T),
    Tables.dictrowtable(T),
    Tables.dictcolumntable(T),
    DataFrames.DataFrame(T);

# StatsModels.jl @formula:
f = @formula(t ~ c + a)
43+
44+
45+
# # TESTS
46+
47+
learner = LearnTestAPI.ConstantClassifier()
48+
@testapi learner (X1, y)
49+
@testapi learner (X2, y) (X3, y) (X4, y) (T1, :t) (T2, :t) (T3, f) (T4, f) verbosity=0
50+
51+
@testset "extra tests for constant classifier" begin
52+
model = fit(learner, (x, y))
53+
@test predict(model, x) == fill('A', 3n)
54+
@test pdf.(predict(model, Distribution(), x), 'A') fill(2/3, 3n)
55+
end
56+
57+
true

test/learners/ensembling.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ learner = LearnTestAPI.Ensemble(atom; n=4, rng)
3030
@testset "extra tests for ensemble" begin
3131
@test LearnAPI.clone(learner) == learner
3232
@test LearnAPI.target(learner, data) == y
33-
@test LearnAPI.features(learner, data) == X
33+
@test LearnAPI.features(learner, data).features == Tables.matrix(X)'
3434

3535
model = @test_logs(
3636
(:info, r"Trained 4 ridge"),

test/learners/regression.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ data = (X, y)
2222
learner = LearnTestAPI.Ridge(lambda=0.5)
2323
@testapi learner data verbosity=1
2424

25-
@testset "extra tests for ridge regression" begin
25+
@testset "extra tests for ridge regressor" begin
2626
@test :(LearnAPI.obs) in LearnAPI.functions(learner)
2727

2828
@test LearnAPI.target(learner, data) == y

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ test_files = [
44
"tools.jl",
55
"learners/static_algorithms.jl",
66
"learners/regression.jl",
7+
"learners/classification.jl",
78
"learners/ensembling.jl",
89
# "learners/gradient_descent.jl",
910
"learners/incremental_algorithms.jl",

0 commit comments

Comments
 (0)