
Commit 4b7f708 (1 parent: aed8ef0)

update perceptron classifier: param replacement API change

File tree: 3 files changed (+47 -57 lines)


src/learners/gradient_descent.jl

Lines changed: 30 additions & 40 deletions
@@ -1,8 +1,8 @@
-# THIS FILE IS NOT INCLUDED
+# THIS FILE IS NOT INCLUDED BUT EXISTS AS AN IMPLEMENTATION EXEMPLAR
 
 # This file defines:
-
 # - `PerceptronClassifier(; epochs=50, optimiser=Optimisers.Adam(), rng=Random.default_rng())
+
 using LearnAPI
 using Random
 using Statistics

@@ -49,7 +49,7 @@ end
 """
     corefit(perceptron, optimiser, X, y_hot, epochs, state, verbosity)
 
-Return updated `perceptron`, `state` and training losses by carrying out gradient descent
+Return updated `perceptron`, `state`, and training losses by carrying out gradient descent
 for the specified number of `epochs`.
 
 - `perceptron`: component array with components `weights` and `bias`

@@ -108,13 +108,7 @@ point predictions with `predict(model, Point(), Xnew)`.
 
 # Warm restart options
 
-    update_observations(model, newdata; replacements...)
-
-Return an updated model, with the weights and bias of the previously learned perceptron
-used as the starting state in new gradient descent updates. Adopt any specified
-hyperparameter `replacements` (properties of `LearnAPI.learner(model)`).
-
-    update(model, newdata; epochs=n, replacements...)
+    update(model, newdata, :epochs=>n, other_replacements...; verbosity=1)
 
 If `Δepochs = n - perceptron.epochs` is non-negative, then return an updated model, with
 the weights and bias of the previously learned perceptron used as the starting state in

@@ -123,17 +117,18 @@ instead of the previous training data. Any other hyperparaameter `replacements`
 adopted. If `Δepochs` is negative or not specified, instead return `fit(learner,
 newdata)`, where `learner=LearnAPI.clone(learner; epochs=n, replacements....)`.
 
+    update_observations(model, newdata, replacements...; verbosity=1)
+
+Return an updated model, with the weights and bias of the previously learned perceptron
+used as the starting state in new gradient descent updates. Adopt any specified
+hyperparameter `replacements` (properties of `LearnAPI.learner(model)`).
+
 """
 PerceptronClassifier(; epochs=50, optimiser=Optimisers.Adam(), rng=Random.default_rng()) =
     PerceptronClassifier(epochs, optimiser, rng)
 
 
-# ### Data interface
-
-# For raw training data:
-LearnAPI.target(learner::PerceptronClassifier, data::Tuple) = last(data)
-
-# For wrapping pre-processed training data (output of `obs(learner, data)`):
+# Type for internal representation of data (output of `obs(learner, data)`):
 struct PerceptronClassifierObs
     X::Matrix{Float32}
     y_hot::BitMatrix # one-hot encoded target

@@ -164,15 +159,19 @@ Base.getindex(observations::PerceptronClassifierObs, I) = PerceptronClassifierOb
     observations.classes,
 )
 
+# training data deconstructors:
 LearnAPI.target(
     learner::PerceptronClassifier,
     observations::PerceptronClassifierObs,
 ) = decode(observations.y_hot, observations.classes)
-
+LearnAPI.target(learner::PerceptronClassifier, data) =
+    LearnAPI.target(learner, obs(learner, data))
 LearnAPI.features(
     learner::PerceptronClassifier,
     observations::PerceptronClassifierObs,
 ) = observations.X
+LearnAPI.features(learner::PerceptronClassifier, data) =
+    LearnAPI.features(learner, obs(learner, data))
 
 # Note that data consumed by `predict` needs no pre-processing, so no need to overload
 # `obs(model, data)`.

@@ -229,9 +228,9 @@ LearnAPI.fit(learner::PerceptronClassifier, data; kwargs...) =
 # see the `PerceptronClassifier` docstring for `update_observations` logic.
 function LearnAPI.update_observations(
     model::PerceptronClassifierFitted,
-    observations_new::PerceptronClassifierObs;
+    observations_new::PerceptronClassifierObs,
+    replacements...;
     verbosity=1,
-    replacements...,
 )
 
     # unpack data:

@@ -243,7 +242,7 @@ function LearnAPI.update_observations(
     classes == model.classes || error("New training target has incompatible classes.")
 
     learner_old = LearnAPI.learner(model)
-    learner = LearnAPI.clone(learner_old; replacements...)
+    learner = LearnAPI.clone(learner_old, replacements...)
 
     perceptron = model.perceptron
     state = model.state

@@ -255,15 +254,15 @@
 
     return PerceptronClassifierFitted(learner, perceptron, state, classes, losses)
 end
-LearnAPI.update_observations(model::PerceptronClassifierFitted, data; kwargs...) =
-    update_observations(model, obs(LearnAPI.learner(model), data); kwargs...)
+LearnAPI.update_observations(model::PerceptronClassifierFitted, data, args...; kwargs...) =
+    update_observations(model, obs(LearnAPI.learner(model), data), args...; kwargs...)
 
 # see the `PerceptronClassifier` docstring for `update` logic.
 function LearnAPI.update(
     model::PerceptronClassifierFitted,
-    observations::PerceptronClassifierObs;
+    observations::PerceptronClassifierObs,
+    replacements...;
     verbosity=1,
-    replacements...,
 )
 
     # unpack data:

@@ -275,24 +274,25 @@ function LearnAPI.update(
     classes == model.classes || error("New training target has incompatible classes.")
 
     learner_old = LearnAPI.learner(model)
-    learner = LearnAPI.clone(learner_old; replacements...)
-    :epochs in keys(replacements) || return fit(learner, observations)
+    learner = LearnAPI.clone(learner_old, replacements...)
+    :epochs in keys(replacements) || return fit(learner, observations; verbosity)
 
     perceptron = model.perceptron
     state = model.state
    losses = model.losses
 
     epochs = learner.epochs
     Δepochs = epochs - learner_old.epochs
-    epochs < 0 && return fit(model, learner)
+    epochs < 0 && return fit(model, learner; verbosity)
 
-    perceptron, state, losses_new = corefit(perceptron, X, y_hot, Δepochs, state, verbosity)
+    perceptron, state, losses_new =
+        corefit(perceptron, X, y_hot, Δepochs, state, verbosity)
     losses = vcat(losses, losses_new)
 
     return PerceptronClassifierFitted(learner, perceptron, state, classes, losses)
 end
-LearnAPI.update(model::PerceptronClassifierFitted, data; kwargs...) =
-    update(model, obs(LearnAPI.learner(model), data); kwargs...)
+LearnAPI.update(model::PerceptronClassifierFitted, data, args...; kwargs...) =
+    update(model, obs(LearnAPI.learner(model), data), args...; kwargs...)
 
 
 # ### Predict

@@ -335,13 +335,3 @@ LearnAPI.training_losses(model::PerceptronClassifierFitted) = model.losses
     :(LearnAPI.training_losses),
 )
 )
-
-
-# ### Convenience methods
-
-LearnAPI.fit(learner::PerceptronClassifier, X, y; kwargs...) =
-    fit(learner, (X, y); kwargs...)
-LearnAPI.update_observations(learner::PerceptronClassifier, X, y; kwargs...) =
-    update_observations(learner, (X, y); kwargs...)
-LearnAPI.update(learner::PerceptronClassifier, X, y; kwargs...) =
-    update(learner, (X, y); kwargs...)
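
For orientation, a minimal sketch of the calling convention this commit adopts: hyperparameter replacements move from keyword arguments to positional `:name => value` pairs, placed after the data argument and before the keywords. (`X`, `y`, `Xnew`, `ynew` are illustrative placeholders, not names from the diff.)

    learner = PerceptronClassifier(epochs=40)
    model = fit(learner, (X, y); verbosity=0)

    # old style, removed by this commit:
    #     update(model, (X, y); epochs=70, verbosity=0)

    # new style, replacements as positional pairs:
    model = update(model, (X, y), :epochs => 70; verbosity=0)

    # `update_observations` follows the same pattern:
    model = update_observations(model, (Xnew, ynew), :epochs => 30; verbosity=0)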

src/testapi.jl

Lines changed: 10 additions & 5 deletions
@@ -163,12 +163,12 @@ macro testapi(learner, data...)
     if _is_static
         model =
             @logged_testset $FIT_IS_STATIC verbosity begin
-                LearnAPI.fit(learner; verbosity=verbosity-1)
+                LearnAPI.fit(learner; verbosity=verbosity - 1)
             end
     else
         model =
             @logged_testset $FIT_IS_NOT_STATIC verbosity begin
-                LearnAPI.fit(learner, data; verbosity=verbosity-1)
+                LearnAPI.fit(learner, data; verbosity=verbosity - 1)
             end
     end
 

@@ -342,7 +342,7 @@ macro testapi(learner, data...)
     @logged_testset $SELECTED_FOR_FIT verbosity begin
         data3 = LearnTestAPI.learner_get(learner, data)
         if _data_interface isa LearnAPI.RandomAccess
-            LearnAPI.fit(learner, data3; verbosity=verbosity-1)
+            LearnAPI.fit(learner, data3; verbosity=verbosity - 1)
         else
             nothing
         end

@@ -493,14 +493,19 @@ macro testapi(learner, data...)
     if :(LearnAPI.update) in _functions
         _is_static && throw($ERR_STATIC_UPDATE)
         @logged_testset $UPDATE verbosity begin
-            LearnAPI.update(model, data; verbosity=0)
+            LearnAPI.update(model, data; verbosity=verbosity - 1)
         end
         # only test hyperparameter replacement in case of iteration parameter:
         iter = LearnAPI.iteration_parameter(learner)
         if !isnothing(iter)
             @logged_testset $UPDATE_ITERATIONS verbosity begin
                 n = getproperty(learner, iter)
-                newmodel = LearnAPI.update(model, data, iter=>n+1; verbosity=0)
+                newmodel = LearnAPI.update(
+                    model,
+                    data,
+                    iter=>n+1;
+                    verbosity=verbosity - 1,
+                )
                 newlearner = LearnAPI.clone(learner, iter=>n+1)
                 Test.@test LearnAPI.learner(newmodel) == newlearner
                 abinitiomodel = LearnAPI.fit(newlearner, data; verbosity=0)
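
In user code, the pair-based replacement path exercised by the macro looks like the following sketch (assuming a `learner` that declares an iteration parameter, and `model` and `data` as in the macro body above; the `@assert` mirrors the `Test.@test` in the diff):

    iter = LearnAPI.iteration_parameter(learner)   # e.g. :epochs
    n = getproperty(learner, iter)
    newmodel = LearnAPI.update(model, data, iter => n + 1; verbosity=0)
    newlearner = LearnAPI.clone(learner, iter => n + 1)
    @assert LearnAPI.learner(newmodel) == newlearner

A pair like `iter => n + 1` can name a hyperparameter computed at runtime, which the old keyword-argument form could only express by splatting a named tuple; this appears to be the motivation for the positional-pair API.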

test/learners/gradient_descent.jl

Lines changed: 7 additions & 12 deletions
@@ -1,7 +1,7 @@
 # THIS FILE IS NOT INCLUDED BY /test/runtests.jl because of heavy dependencies. The
 # source file, "/src/learners/gradient_descent.jl" is not included in the package, but
 # exits as a learner exemplar. Next line manually loads the source:
-include(joinpath(@__DIR__, "..", "..", "src", "learners", "gradient_descent.jl")
+include(joinpath(@__DIR__, "..", "..", "src", "learners", "gradient_descent.jl"))
 
 using Test
 using LearnAPI

@@ -42,15 +42,10 @@ rng = StableRNG(123)
 learner =
     PerceptronClassifier(; optimiser=Optimisers.Adam(0.01), epochs=40, rng)
 
-@testapi learner (X, y) verbosity=1
+@testapi learner (X, y) verbosity=0 # use verbosity=1 to debug
 
-@testset "PerceptronClassfier" begin
-    @test LearnAPI.clone(learner) == learner
-    @test :(LearnAPI.update) in LearnAPI.functions(learner)
-    @test LearnAPI.target(learner, (X, y)) == y
-    @test LearnAPI.features(learner, (X, y)) == X
-
-    model40 = fit(learner, Xtrain, ytrain; verbosity=0)
+@testset "extra tests for perceptron classfier" begin
+    model40 = fit(learner, (Xtrain, ytrain); verbosity=0)
 
     # 40 epochs is sufficient for 90% accuracy in this case:
     @test sum(predict(model40, Point(), Xtest) .== ytest)/length(ytest) > 0.9

@@ -60,16 +55,16 @@ learner =
     @test predict(model40, Xtest) ≈ ŷ40
 
     # add 30 epochs in an `update`:
-    model70 = update(model40, Xtrain, y[train]; verbosity=0, epochs=70)
+    model70 = update(model40, (Xtrain, y[train]), :epochs=>70; verbosity=0)
     ŷ70 = predict(model70, Xtest);
     @test !(ŷ70 ≈ ŷ40)
 
     # compare with cold restart:
-    model = fit(LearnAPI.clone(learner; epochs=70), Xtrain, y[train]; verbosity=0);
+    model = fit(LearnAPI.clone(learner; epochs=70), (Xtrain, y[train]); verbosity=0);
     @test ŷ70 ≈ predict(model, Xtest)
 
     # instead add 30 epochs using `update_observations` instead:
-    model70b = update_observations(model40, Xtrain, y[train]; verbosity=0, epochs=30)
+    model70b = update_observations(model40, (Xtrain, y[train]), :epochs=>30; verbosity=0)
     @test ŷ70 ≈ predict(model70b, Xtest) ≈ predict(model, Xtest)
 end
 
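Summarizing the migration applied throughout these tests (names as in the diff above; a sketch, not additional test code):

    # before: data as separate positional arguments, replacements as keywords:
    model70 = update(model40, Xtrain, y[train]; verbosity=0, epochs=70)

    # after: data as a single tuple, replacements as positional pairs:
    model70 = update(model40, (Xtrain, y[train]), :epochs=>70; verbosity=0)

The tuple form is required because this commit also deletes the `fit(learner, X, y)`-style convenience methods from the source file.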
