1
- # THIS FILE IS NOT INCLUDED
1
+ # THIS FILE IS NOT INCLUDED BUT EXISTS AS AN IMPLEMENTATION EXEMPLAR
2
2
3
3
# This file defines:
4
-
5
4
# - `PerceptronClassifier(; epochs=50, optimiser=Optimisers.Adam(), rng=Random.default_rng())
5
+
6
6
using LearnAPI
7
7
using Random
8
8
using Statistics
49
49
"""
50
50
corefit(perceptron, optimiser, X, y_hot, epochs, state, verbosity)
51
51
52
- Return updated `perceptron`, `state` and training losses by carrying out gradient descent
52
+ Return updated `perceptron`, `state`, and training losses by carrying out gradient descent
53
53
for the specified number of `epochs`.
54
54
55
55
- `perceptron`: component array with components `weights` and `bias`
@@ -108,13 +108,7 @@ point predictions with `predict(model, Point(), Xnew)`.
108
108
109
109
# Warm restart options
110
110
111
- update_observations(model, newdata; replacements...)
112
-
113
- Return an updated model, with the weights and bias of the previously learned perceptron
114
- used as the starting state in new gradient descent updates. Adopt any specified
115
- hyperparameter `replacements` (properties of `LearnAPI.learner(model)`).
116
-
117
- update(model, newdata; epochs=n, replacements...)
111
+ update(model, newdata, :epochs=>n, other_replacements...; verbosity=1)
118
112
119
113
If `Δepochs = n - perceptron.epochs` is non-negative, then return an updated model, with
120
114
the weights and bias of the previously learned perceptron used as the starting state in
@@ -123,17 +117,18 @@ instead of the previous training data. Any other hyperparaameter `replacements`
123
117
adopted. If `Δepochs` is negative or not specified, instead return `fit(learner,
124
118
newdata)`, where `learner=LearnAPI.clone(learner; epochs=n, replacements....)`.
125
119
120
+ update_observations(model, newdata, replacements...; verbosity=1)
121
+
122
+ Return an updated model, with the weights and bias of the previously learned perceptron
123
+ used as the starting state in new gradient descent updates. Adopt any specified
124
+ hyperparameter `replacements` (properties of `LearnAPI.learner(model)`).
125
+
126
126
"""
127
127
PerceptronClassifier (; epochs= 50 , optimiser= Optimisers. Adam (), rng= Random. default_rng ()) =
128
128
PerceptronClassifier (epochs, optimiser, rng)
129
129
130
130
131
- # ### Data interface
132
-
133
- # For raw training data:
134
- LearnAPI. target (learner:: PerceptronClassifier , data:: Tuple ) = last (data)
135
-
136
- # For wrapping pre-processed training data (output of `obs(learner, data)`):
131
+ # Type for internal representation of data (output of `obs(learner, data)`):
137
132
struct PerceptronClassifierObs
138
133
X:: Matrix{Float32}
139
134
y_hot:: BitMatrix # one-hot encoded target
@@ -164,15 +159,19 @@ Base.getindex(observations::PerceptronClassifierObs, I) = PerceptronClassifierOb
164
159
observations. classes,
165
160
)
166
161
162
+ # training data deconstructors:
167
163
LearnAPI. target (
168
164
learner:: PerceptronClassifier ,
169
165
observations:: PerceptronClassifierObs ,
170
166
) = decode (observations. y_hot, observations. classes)
171
-
167
+ LearnAPI. target (learner:: PerceptronClassifier , data) =
168
+ LearnAPI. target (learner, obs (learner, data))
172
169
LearnAPI. features (
173
170
learner:: PerceptronClassifier ,
174
171
observations:: PerceptronClassifierObs ,
175
172
) = observations. X
173
+ LearnAPI. features (learner:: PerceptronClassifier , data) =
174
+ LearnAPI. features (learner, obs (learner, data))
176
175
177
176
# Note that data consumed by `predict` needs no pre-processing, so no need to overload
178
177
# `obs(model, data)`.
@@ -229,9 +228,9 @@ LearnAPI.fit(learner::PerceptronClassifier, data; kwargs...) =
229
228
# see the `PerceptronClassifier` docstring for `update_observations` logic.
230
229
function LearnAPI. update_observations (
231
230
model:: PerceptronClassifierFitted ,
232
- observations_new:: PerceptronClassifierObs ;
231
+ observations_new:: PerceptronClassifierObs ,
232
+ replacements... ;
233
233
verbosity= 1 ,
234
- replacements... ,
235
234
)
236
235
237
236
# unpack data:
@@ -243,7 +242,7 @@ function LearnAPI.update_observations(
243
242
classes == model. classes || error (" New training target has incompatible classes." )
244
243
245
244
learner_old = LearnAPI. learner (model)
246
- learner = LearnAPI. clone (learner_old; replacements... )
245
+ learner = LearnAPI. clone (learner_old, replacements... )
247
246
248
247
perceptron = model. perceptron
249
248
state = model. state
@@ -255,15 +254,15 @@ function LearnAPI.update_observations(
255
254
256
255
return PerceptronClassifierFitted (learner, perceptron, state, classes, losses)
257
256
end
258
- LearnAPI. update_observations (model:: PerceptronClassifierFitted , data; kwargs... ) =
259
- update_observations (model, obs (LearnAPI. learner (model), data); kwargs... )
257
+ LearnAPI. update_observations (model:: PerceptronClassifierFitted , data, args ... ; kwargs... ) =
258
+ update_observations (model, obs (LearnAPI. learner (model), data), args ... ; kwargs... )
260
259
261
260
# see the `PerceptronClassifier` docstring for `update` logic.
262
261
function LearnAPI. update (
263
262
model:: PerceptronClassifierFitted ,
264
- observations:: PerceptronClassifierObs ;
263
+ observations:: PerceptronClassifierObs ,
264
+ replacements... ;
265
265
verbosity= 1 ,
266
- replacements... ,
267
266
)
268
267
269
268
# unpack data:
@@ -275,24 +274,25 @@ function LearnAPI.update(
275
274
classes == model. classes || error (" New training target has incompatible classes." )
276
275
277
276
learner_old = LearnAPI. learner (model)
278
- learner = LearnAPI. clone (learner_old; replacements... )
279
- :epochs in keys (replacements) || return fit (learner, observations)
277
+ learner = LearnAPI. clone (learner_old, replacements... )
278
+ :epochs in keys (replacements) || return fit (learner, observations; verbosity )
280
279
281
280
perceptron = model. perceptron
282
281
state = model. state
283
282
losses = model. losses
284
283
285
284
epochs = learner. epochs
286
285
Δepochs = epochs - learner_old. epochs
287
- epochs < 0 && return fit (model, learner)
286
+ epochs < 0 && return fit (model, learner; verbosity )
288
287
289
- perceptron, state, losses_new = corefit (perceptron, X, y_hot, Δepochs, state, verbosity)
288
+ perceptron, state, losses_new =
289
+ corefit (perceptron, X, y_hot, Δepochs, state, verbosity)
290
290
losses = vcat (losses, losses_new)
291
291
292
292
return PerceptronClassifierFitted (learner, perceptron, state, classes, losses)
293
293
end
294
- LearnAPI. update (model:: PerceptronClassifierFitted , data; kwargs... ) =
295
- update (model, obs (LearnAPI. learner (model), data); kwargs... )
294
+ LearnAPI. update (model:: PerceptronClassifierFitted , data, args ... ; kwargs... ) =
295
+ update (model, obs (LearnAPI. learner (model), data), args ... ; kwargs... )
296
296
297
297
298
298
# ### Predict
@@ -335,13 +335,3 @@ LearnAPI.training_losses(model::PerceptronClassifierFitted) = model.losses
335
335
:(LearnAPI. training_losses),
336
336
)
337
337
)
338
-
339
-
340
- # ### Convenience methods
341
-
342
- LearnAPI. fit (learner:: PerceptronClassifier , X, y; kwargs... ) =
343
- fit (learner, (X, y); kwargs... )
344
- LearnAPI. update_observations (learner:: PerceptronClassifier , X, y; kwargs... ) =
345
- update_observations (learner, (X, y); kwargs... )
346
- LearnAPI. update (learner:: PerceptronClassifier , X, y; kwargs... ) =
347
- update (learner, (X, y); kwargs... )
0 commit comments