Skip to content

Commit abc0f94

Browse files
committed
address a breaking change in LearnDataFrontEnds.jl
1 parent d84a83d commit abc0f94

File tree

2 files changed

+30
-32
lines changed

2 files changed

+30
-32
lines changed

src/learners/classification.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ struct ConstantClassifierFitted
2424
learner::ConstantClassifier
2525
probabilities
2626
names::Vector{Symbol}
27-
classes_seen
27+
levels_seen
2828
codes_seen
2929
decoder
3030
end
@@ -52,7 +52,7 @@ function LearnAPI.fit(
5252

5353
y = observations.target # integer "codes"
5454
names = observations.names
55-
classes_seen = observations.classes_seen
55+
levels_seen = observations.levels_seen
5656
codes_seen = sort(unique(y))
5757
decoder = observations.decoder
5858

@@ -64,7 +64,7 @@ function LearnAPI.fit(
6464
learner,
6565
probabilities,
6666
names,
67-
classes_seen,
67+
levels_seen,
6868
codes_seen,
6969
decoder,
7070
)
@@ -94,7 +94,7 @@ function LearnAPI.predict(
9494
probs = model.probabilities
9595
# repeat vertically to get rows of a matrix:
9696
probs_matrix = reshape(repeat(probs, n), (length(probs), n))'
97-
return CategoricalDistributions.UnivariateFinite(model.classes_seen, probs_matrix)
97+
return CategoricalDistributions.UnivariateFinite(model.levels_seen, probs_matrix)
9898
end
9999
LearnAPI.predict(model::ConstantClassifierFitted, ::Distribution, data) =
100100
predict(model, Distribution(), obs(model, data))

src/learners/gradient_descent.jl

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
# THIS FILE IS NOT INCLUDED BUT EXISTS AS AN IMPLEMENTATION EXEMPLAR
22

3-
# TODO: This file should be updated after release of CategoricalDistributions 0.2 as
4-
# `classes` is deprecated there.
5-
63
# This file defines:
74
# - `PerceptronClassifier(; epochs=50, optimiser=Optimisers.Adam(), rng=Random.default_rng())
85

@@ -13,6 +10,7 @@ using StableRNGs
1310
import Optimisers
1411
import Zygote
1512
import NNlib
13+
import CategoricalArrays
1614
import CategoricalDistributions
1715
import CategoricalDistributions: pdf, mode
1816
import ComponentArrays
@@ -58,7 +56,7 @@ for the specified number of `epochs`.
5856
- `perceptron`: component array with components `weights` and `bias`
5957
- `optimiser`: optimiser from Optimisers.jl
6058
- `X`: feature matrix, of size `(p, n)`
61-
- `y_hot`: one-hot encoded target, of size `(nclasses, n)`
59+
- `y_hot`: one-hot encoded target, of size `(nlevels, n)`
6260
- `epochs`: number of epochs
6361
- `state`: optimiser state
6462
@@ -135,38 +133,38 @@ PerceptronClassifier(; epochs=50, optimiser=Optimisers.Adam(), rng=Random.defaul
135133
struct PerceptronClassifierObs
136134
X::Matrix{Float32}
137135
y_hot::BitMatrix # one-hot encoded target
138-
classes # the (ordered) pool of `y`, as `CategoricalValue`s
136+
levels # the (ordered) pool of `y`, as `CategoricalValue`s
139137
end
140138

141139
# For pre-processing the training data:
142140
function LearnAPI.obs(::PerceptronClassifier, data::Tuple)
143141
X, y = data
144-
classes = CategoricalDistributions.classes(y)
145-
y_hot = classes .== permutedims(y) # one-hot encoding
146-
return PerceptronClassifierObs(X, y_hot, classes)
142+
levels = CategoricalArrays.levels(y)
143+
y_hot = levels .== permutedims(y) # one-hot encoding
144+
return PerceptronClassifierObs(X, y_hot, levels)
147145
end
148146
LearnAPI.obs(::PerceptronClassifier, observations::PerceptronClassifierObs) =
149147
observations # involutivity
150148

151149
# helper:
152-
function decode(y_hot, classes)
150+
function decode(y_hot, levels)
153151
n = size(y_hot, 2)
154-
[only(classes[y_hot[:,i]]) for i in 1:n]
152+
[only(levels[y_hot[:,i]]) for i in 1:n]
155153
end
156154

157155
# implement `RandomAccess()` interface for output of `obs`:
158156
Base.length(observations::PerceptronClassifierObs) = size(observations.y_hot, 2)
159157
Base.getindex(observations::PerceptronClassifierObs, I) = PerceptronClassifierObs(
160158
observations.X[:, I],
161159
observations.y_hot[:, I],
162-
observations.classes,
160+
observations.levels,
163161
)
164162

165163
# training data deconstructors:
166164
LearnAPI.target(
167165
learner::PerceptronClassifier,
168166
observations::PerceptronClassifierObs,
169-
) = decode(observations.y_hot, observations.classes)
167+
) = decode(observations.y_hot, observations.levels)
170168
LearnAPI.target(learner::PerceptronClassifier, data) =
171169
LearnAPI.target(learner, obs(learner, data))
172170
LearnAPI.features(
@@ -187,7 +185,7 @@ struct PerceptronClassifierFitted
187185
learner::PerceptronClassifier
188186
perceptron # component array storing weights and bias
189187
state # optimiser state
190-
classes # target classes
188+
levels # target levels
191189
losses
192190
end
193191

@@ -208,20 +206,20 @@ function LearnAPI.fit(
208206
# unpack data:
209207
X = observations.X
210208
y_hot = observations.y_hot
211-
classes = observations.classes
212-
nclasses = length(classes)
209+
levels = observations.levels
210+
nlevels = length(levels)
213211

214212
# initialize bias and weights:
215-
weights = randn(rng, Float32, nclasses, p)
216-
bias = zeros(Float32, nclasses)
213+
weights = randn(rng, Float32, nlevels, p)
214+
bias = zeros(Float32, nlevels)
217215
perceptron = (; weights, bias) |> ComponentArrays.ComponentArray
218216

219217
# initialize optimiser:
220218
state = Optimisers.setup(optimiser, perceptron)
221219

222220
perceptron, state, losses = corefit(perceptron, X, y_hot, epochs, state, verbosity)
223221

224-
return PerceptronClassifierFitted(learner, perceptron, state, classes, losses)
222+
return PerceptronClassifierFitted(learner, perceptron, state, levels, losses)
225223
end
226224

227225
# `fit` for unprocessed data:
@@ -239,10 +237,10 @@ function LearnAPI.update_observations(
239237
# unpack data:
240238
X = observations_new.X
241239
y_hot = observations_new.y_hot
242-
classes = observations_new.classes
243-
nclasses = length(classes)
240+
levels = observations_new.levels
241+
nlevels = length(levels)
244242

245-
classes == model.classes || error("New training target has incompatible classes.")
243+
levels == model.levels || error("New training target has incompatible levels.")
246244

247245
learner_old = LearnAPI.learner(model)
248246
learner = LearnAPI.clone(learner_old, replacements...)
@@ -255,7 +253,7 @@ function LearnAPI.update_observations(
255253
perceptron, state, losses_new = corefit(perceptron, X, y_hot, epochs, state, verbosity)
256254
losses = vcat(losses, losses_new)
257255

258-
return PerceptronClassifierFitted(learner, perceptron, state, classes, losses)
256+
return PerceptronClassifierFitted(learner, perceptron, state, levels, losses)
259257
end
260258
LearnAPI.update_observations(model::PerceptronClassifierFitted, data, args...; kwargs...) =
261259
update_observations(model, obs(LearnAPI.learner(model), data), args...; kwargs...)
@@ -271,10 +269,10 @@ function LearnAPI.update(
271269
# unpack data:
272270
X = observations.X
273271
y_hot = observations.y_hot
274-
classes = observations.classes
275-
nclasses = length(classes)
272+
levels = observations.levels
273+
nlevels = length(levels)
276274

277-
classes == model.classes || error("New training target has incompatible classes.")
275+
levels == model.levels || error("New training target has incompatible levels.")
278276

279277
learner_old = LearnAPI.learner(model)
280278
learner = LearnAPI.clone(learner_old, replacements...)
@@ -292,7 +290,7 @@ function LearnAPI.update(
292290
corefit(perceptron, X, y_hot, Δepochs, state, verbosity)
293291
losses = vcat(losses, losses_new)
294292

295-
return PerceptronClassifierFitted(learner, perceptron, state, classes, losses)
293+
return PerceptronClassifierFitted(learner, perceptron, state, levels, losses)
296294
end
297295
LearnAPI.update(model::PerceptronClassifierFitted, data, args...; kwargs...) =
298296
update(model, obs(LearnAPI.learner(model), data), args...; kwargs...)
@@ -302,9 +300,9 @@ LearnAPI.update(model::PerceptronClassifierFitted, data, args...; kwargs...) =
302300

303301
function LearnAPI.predict(model::PerceptronClassifierFitted, ::Distribution, Xnew)
304302
perceptron = model.perceptron
305-
classes = model.classes
303+
levels = model.levels
306304
probs = perceptron.weights*Xnew .+ perceptron.bias |> NNlib.softmax
307-
return CategoricalDistributions.UnivariateFinite(classes, probs')
305+
return CategoricalDistributions.UnivariateFinite(levels, probs')
308306
end
309307

310308
LearnAPI.predict(model::PerceptronClassifierFitted, ::Point, Xnew) =

0 commit comments

Comments
 (0)