1- # THIS FILE IS NOT INCLUDED
1+ # THIS FILE IS NOT INCLUDED BUT EXISTS AS AN IMPLEMENTATION EXEMPLAR
22
33# This file defines:
4-
54# - `PerceptronClassifier(; epochs=50, optimiser=Optimisers.Adam(), rng=Random.default_rng())`
5+
66using LearnAPI
77using Random
88using Statistics
4949"""
5050 corefit(perceptron, optimiser, X, y_hot, epochs, state, verbosity)
5151
52- Return updated `perceptron`, `state` and training losses by carrying out gradient descent
52+ Return updated `perceptron`, `state`, and training losses by carrying out gradient descent
5353for the specified number of `epochs`.
5454
5555- `perceptron`: component array with components `weights` and `bias`
@@ -108,13 +108,7 @@ point predictions with `predict(model, Point(), Xnew)`.
108108
109109# Warm restart options
110110
111- update_observations(model, newdata; replacements...)
112-
113- Return an updated model, with the weights and bias of the previously learned perceptron
114- used as the starting state in new gradient descent updates. Adopt any specified
115- hyperparameter `replacements` (properties of `LearnAPI.learner(model)`).
116-
117- update(model, newdata; epochs=n, replacements...)
111+ update(model, newdata, :epochs=>n, other_replacements...; verbosity=1)
118112
119113If `Δepochs = n - perceptron.epochs` is non-negative, then return an updated model, with
120114the weights and bias of the previously learned perceptron used as the starting state in
@@ -123,17 +117,18 @@ instead of the previous training data. Any other hyperparameter `replacements` are
123117adopted. If `Δepochs` is negative or not specified, instead return `fit(learner,
124118newdata)`, where `learner=LearnAPI.clone(learner; epochs=n, replacements...)`.
125119
120+ update_observations(model, newdata, replacements...; verbosity=1)
121+
122+ Return an updated model, with the weights and bias of the previously learned perceptron
123+ used as the starting state in new gradient descent updates. Adopt any specified
124+ hyperparameter `replacements` (properties of `LearnAPI.learner(model)`).
125+
126126"""
127127PerceptronClassifier (; epochs= 50 , optimiser= Optimisers. Adam (), rng= Random. default_rng ()) =
128128 PerceptronClassifier (epochs, optimiser, rng)
129129
130130
131- # ### Data interface
132-
133- # For raw training data:
134- LearnAPI. target (learner:: PerceptronClassifier , data:: Tuple ) = last (data)
135-
136- # For wrapping pre-processed training data (output of `obs(learner, data)`):
131+ # Type for internal representation of data (output of `obs(learner, data)`):
137132struct PerceptronClassifierObs
138133 X:: Matrix{Float32}
139134 y_hot:: BitMatrix # one-hot encoded target
@@ -164,15 +159,19 @@ Base.getindex(observations::PerceptronClassifierObs, I) = PerceptronClassifierOb
164159 observations. classes,
165160)
166161
162+ # training data deconstructors:
167163LearnAPI. target (
168164 learner:: PerceptronClassifier ,
169165 observations:: PerceptronClassifierObs ,
170166) = decode (observations. y_hot, observations. classes)
171-
167+ LearnAPI. target (learner:: PerceptronClassifier , data) =
168+ LearnAPI. target (learner, obs (learner, data))
172169LearnAPI. features (
173170 learner:: PerceptronClassifier ,
174171 observations:: PerceptronClassifierObs ,
175172) = observations. X
173+ LearnAPI. features (learner:: PerceptronClassifier , data) =
174+ LearnAPI. features (learner, obs (learner, data))
176175
177176# Note that data consumed by `predict` needs no pre-processing, so no need to overload
178177# `obs(model, data)`.
@@ -229,9 +228,9 @@ LearnAPI.fit(learner::PerceptronClassifier, data; kwargs...) =
229228# see the `PerceptronClassifier` docstring for `update_observations` logic.
230229function LearnAPI. update_observations (
231230 model:: PerceptronClassifierFitted ,
232- observations_new:: PerceptronClassifierObs ;
231+ observations_new:: PerceptronClassifierObs ,
232+ replacements... ;
233233 verbosity= 1 ,
234- replacements... ,
235234 )
236235
237236 # unpack data:
@@ -243,7 +242,7 @@ function LearnAPI.update_observations(
243242 classes == model. classes || error (" New training target has incompatible classes." )
244243
245244 learner_old = LearnAPI. learner (model)
246- learner = LearnAPI. clone (learner_old; replacements... )
245+ learner = LearnAPI. clone (learner_old, replacements... )
247246
248247 perceptron = model. perceptron
249248 state = model. state
@@ -255,15 +254,15 @@ function LearnAPI.update_observations(
255254
256255 return PerceptronClassifierFitted (learner, perceptron, state, classes, losses)
257256end
258- LearnAPI. update_observations (model:: PerceptronClassifierFitted , data; kwargs... ) =
259- update_observations (model, obs (LearnAPI. learner (model), data); kwargs... )
257+ LearnAPI. update_observations (model:: PerceptronClassifierFitted , data, args ... ; kwargs... ) =
258+ update_observations (model, obs (LearnAPI. learner (model), data), args ... ; kwargs... )
260259
261260# see the `PerceptronClassifier` docstring for `update` logic.
262261function LearnAPI. update (
263262 model:: PerceptronClassifierFitted ,
264- observations:: PerceptronClassifierObs ;
263+ observations:: PerceptronClassifierObs ,
264+ replacements... ;
265265 verbosity= 1 ,
266- replacements... ,
267266 )
268267
269268 # unpack data:
@@ -275,24 +274,25 @@ function LearnAPI.update(
275274 classes == model. classes || error (" New training target has incompatible classes." )
276275
277276 learner_old = LearnAPI. learner (model)
278- learner = LearnAPI. clone (learner_old; replacements... )
279- :epochs in keys (replacements) || return fit (learner, observations)
277+ learner = LearnAPI. clone (learner_old, replacements... )
278+ :epochs in keys (replacements) || return fit (learner, observations; verbosity )
280279
281280 perceptron = model. perceptron
282281 state = model. state
283282 losses = model. losses
284283
285284 epochs = learner. epochs
286285 Δepochs = epochs - learner_old. epochs
287- epochs < 0 && return fit (model, learner)
286+ epochs < 0 && return fit (model, learner; verbosity )
288287
289- perceptron, state, losses_new = corefit (perceptron, X, y_hot, Δepochs, state, verbosity)
288+ perceptron, state, losses_new =
289+ corefit (perceptron, X, y_hot, Δepochs, state, verbosity)
290290 losses = vcat (losses, losses_new)
291291
292292 return PerceptronClassifierFitted (learner, perceptron, state, classes, losses)
293293end
294- LearnAPI. update (model:: PerceptronClassifierFitted , data; kwargs... ) =
295- update (model, obs (LearnAPI. learner (model), data); kwargs... )
294+ LearnAPI. update (model:: PerceptronClassifierFitted , data, args ... ; kwargs... ) =
295+ update (model, obs (LearnAPI. learner (model), data), args ... ; kwargs... )
296296
297297
298298# ### Predict
@@ -335,13 +335,3 @@ LearnAPI.training_losses(model::PerceptronClassifierFitted) = model.losses
335335 :(LearnAPI. training_losses),
336336 )
337337)
338-
339-
340- # ### Convenience methods
341-
342- LearnAPI. fit (learner:: PerceptronClassifier , X, y; kwargs... ) =
343- fit (learner, (X, y); kwargs... )
344- LearnAPI. update_observations (learner:: PerceptronClassifier , X, y; kwargs... ) =
345- update_observations (learner, (X, y); kwargs... )
346- LearnAPI. update (learner:: PerceptronClassifier , X, y; kwargs... ) =
347- update (learner, (X, y); kwargs... )
0 commit comments