 # THIS FILE IS NOT INCLUDED BUT EXISTS AS AN IMPLEMENTATION EXEMPLAR

-# TODO: This file should be updated after release of CategoricalDistributions 0.2 as
-# `classes` is deprecated there.
-
 # This file defines:
 # - `PerceptronClassifier(; epochs=50, optimiser=Optimisers.Adam(), rng=Random.default_rng())`

@@ -13,6 +10,7 @@ using StableRNGs
 import Optimisers
 import Zygote
 import NNlib
+import CategoricalArrays
 import CategoricalDistributions
 import CategoricalDistributions: pdf, mode
 import ComponentArrays
@@ -58,7 +56,7 @@ for the specified number of `epochs`.
 - `perceptron`: component array with components `weights` and `bias`
 - `optimiser`: optimiser from Optimisers.jl
 - `X`: feature matrix, of size `(p, n)`
-- `y_hot`: one-hot encoded target, of size `(nclasses, n)`
+- `y_hot`: one-hot encoded target, of size `(nlevels, n)`
 - `epochs`: number of epochs
 - `state`: optimiser state

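(Aside, not part of the diff: a minimal standalone sketch of the `(nlevels, n)` one-hot layout referenced above. The toy vector `y` and all variable names are illustrative only.)

    import CategoricalArrays

    y = CategoricalArrays.categorical(["dog", "cat", "dog", "bird"])  # toy target, n = 4
    levels = CategoricalArrays.levels(y)  # ["bird", "cat", "dog"], so nlevels = 3
    y_hot = levels .== permutedims(y)     # 3×4 BitMatrix, exactly one `true` per column

    # inverting the encoding (what the `decode` helper below does):
    [only(levels[y_hot[:, i]]) for i in 1:size(y_hot, 2)]  # recovers ["dog", "cat", "dog", "bird"]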
@@ -135,38 +133,38 @@ PerceptronClassifier(; epochs=50, optimiser=Optimisers.Adam(), rng=Random.defaul
 struct PerceptronClassifierObs
     X::Matrix{Float32}
     y_hot::BitMatrix # one-hot encoded target
-    classes # the (ordered) pool of `y`, as `CategoricalValue`s
+    levels # the (ordered) pool of `y`, as `CategoricalValue`s
 end

 # For pre-processing the training data:
 function LearnAPI.obs(::PerceptronClassifier, data::Tuple)
     X, y = data
-    classes = CategoricalDistributions.classes(y)
-    y_hot = classes .== permutedims(y) # one-hot encoding
-    return PerceptronClassifierObs(X, y_hot, classes)
+    levels = CategoricalArrays.levels(y)
+    y_hot = levels .== permutedims(y) # one-hot encoding
+    return PerceptronClassifierObs(X, y_hot, levels)
 end
 LearnAPI.obs(::PerceptronClassifier, observations::PerceptronClassifierObs) =
     observations # involutivity

 # helper:
-function decode(y_hot, classes)
+function decode(y_hot, levels)
     n = size(y_hot, 2)
-    [only(classes[y_hot[:,i]]) for i in 1:n]
+    [only(levels[y_hot[:,i]]) for i in 1:n]
 end

 # implement `RandomAccess()` interface for output of `obs`:
 Base.length(observations::PerceptronClassifierObs) = size(observations.y_hot, 2)
 Base.getindex(observations::PerceptronClassifierObs, I) = PerceptronClassifierObs(
     observations.X[:, I],
     observations.y_hot[:, I],
-    observations.classes,
+    observations.levels,
 )

 # training data deconstructors:
 LearnAPI.target(
     learner::PerceptronClassifier,
     observations::PerceptronClassifierObs,
-) = decode(observations.y_hot, observations.classes)
+) = decode(observations.y_hot, observations.levels)
 LearnAPI.target(learner::PerceptronClassifier, data) =
     LearnAPI.target(learner, obs(learner, data))
 LearnAPI.features(
@@ -187,7 +185,7 @@ struct PerceptronClassifierFitted
     learner::PerceptronClassifier
     perceptron # component array storing weights and bias
     state # optimiser state
-    classes # target classes
+    levels # target levels
     losses
 end

@@ -208,20 +206,20 @@ function LearnAPI.fit(
     # unpack data:
     X = observations.X
     y_hot = observations.y_hot
-    classes = observations.classes
-    nclasses = length(classes)
+    levels = observations.levels
+    nlevels = length(levels)

     # initialize bias and weights:
-    weights = randn(rng, Float32, nclasses, p)
-    bias = zeros(Float32, nclasses)
+    weights = randn(rng, Float32, nlevels, p)
+    bias = zeros(Float32, nlevels)
     perceptron = (; weights, bias) |> ComponentArrays.ComponentArray

     # initialize optimiser:
     state = Optimisers.setup(optimiser, perceptron)

     perceptron, state, losses = corefit(perceptron, X, y_hot, epochs, state, verbosity)

-    return PerceptronClassifierFitted(learner, perceptron, state, classes, losses)
+    return PerceptronClassifierFitted(learner, perceptron, state, levels, losses)
 end

 # `fit` for unprocessed data:
@@ -239,10 +237,10 @@ function LearnAPI.update_observations(
     # unpack data:
     X = observations_new.X
     y_hot = observations_new.y_hot
-    classes = observations_new.classes
-    nclasses = length(classes)
+    levels = observations_new.levels
+    nlevels = length(levels)

-    classes == model.classes || error("New training target has incompatible classes.")
+    levels == model.levels || error("New training target has incompatible levels.")

     learner_old = LearnAPI.learner(model)
     learner = LearnAPI.clone(learner_old, replacements...)
@@ -255,7 +253,7 @@ function LearnAPI.update_observations(
     perceptron, state, losses_new = corefit(perceptron, X, y_hot, epochs, state, verbosity)
     losses = vcat(losses, losses_new)

-    return PerceptronClassifierFitted(learner, perceptron, state, classes, losses)
+    return PerceptronClassifierFitted(learner, perceptron, state, levels, losses)
 end
 LearnAPI.update_observations(model::PerceptronClassifierFitted, data, args...; kwargs...) =
     update_observations(model, obs(LearnAPI.learner(model), data), args...; kwargs...)
@@ -271,10 +269,10 @@ function LearnAPI.update(
     # unpack data:
     X = observations.X
     y_hot = observations.y_hot
-    classes = observations.classes
-    nclasses = length(classes)
+    levels = observations.levels
+    nlevels = length(levels)

-    classes == model.classes || error("New training target has incompatible classes.")
+    levels == model.levels || error("New training target has incompatible levels.")

     learner_old = LearnAPI.learner(model)
     learner = LearnAPI.clone(learner_old, replacements...)
@@ -292,7 +290,7 @@ function LearnAPI.update(
         corefit(perceptron, X, y_hot, Δepochs, state, verbosity)
     losses = vcat(losses, losses_new)

-    return PerceptronClassifierFitted(learner, perceptron, state, classes, losses)
+    return PerceptronClassifierFitted(learner, perceptron, state, levels, losses)
 end
 LearnAPI.update(model::PerceptronClassifierFitted, data, args...; kwargs...) =
     update(model, obs(LearnAPI.learner(model), data), args...; kwargs...)
@@ -302,9 +300,9 @@ LearnAPI.update(model::PerceptronClassifierFitted, data, args...; kwargs...) =

 function LearnAPI.predict(model::PerceptronClassifierFitted, ::Distribution, Xnew)
     perceptron = model.perceptron
-    classes = model.classes
+    levels = model.levels
     probs = perceptron.weights*Xnew .+ perceptron.bias |> NNlib.softmax
-    return CategoricalDistributions.UnivariateFinite(classes, probs')
+    return CategoricalDistributions.UnivariateFinite(levels, probs')
 end

 LearnAPI.predict(model::PerceptronClassifierFitted, ::Point, Xnew) =
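(Aside, not part of the diff: a rough end-to-end usage sketch of this exemplar after the rename. It assumes the file above is loaded, that `Point` and `Distribution` are the LearnAPI target-proxy types the `predict` methods dispatch on, and that `X` (a `(p, n)` `Float32` matrix), `y` (a categorical vector of length `n`), and `Xnew` are placeholder data not defined in the diff.)

    learner = PerceptronClassifier(epochs=10)
    model = LearnAPI.fit(learner, (X, y))                    # X, y: placeholder training data
    LearnAPI.predict(model, LearnAPI.Distribution(), Xnew)   # probabilistic predictions
    LearnAPI.predict(model, LearnAPI.Point(), Xnew)          # point predictions (modes)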