@@ -134,7 +134,7 @@ PerceptronClassifier(; epochs=50, optimiser=Optimisers.Adam(), rng=Random.defaul
134
134
LearnAPI. target (algorithm:: PerceptronClassifier , data:: Tuple ) = last (data)
135
135
136
136
# For wrapping pre-processed training data (output of `obs(algorithm, data)`):
137
- struct PerceptronClassifierObservations
137
+ struct PerceptronClassifierObs
138
138
X:: Matrix{Float32}
139
139
y_hot:: BitMatrix # one-hot encoded target
140
140
classes # the (ordered) pool of `y`, as `CategoricalValue`s
@@ -145,25 +145,25 @@ function LearnAPI.obs(algorithm::PerceptronClassifier, data::Tuple)
145
145
X, y = data
146
146
classes = CategoricalDistributions. classes (y)
147
147
y_hot = classes .== permutedims (y) # one-hot encoding
148
- return PerceptronClassifierObservations (X, y_hot, classes)
148
+ return PerceptronClassifierObs (X, y_hot, classes)
149
149
end
150
150
151
151
# implement `RadomAccess()` interface for output of `obs`:
152
- Base. length (observations:: PerceptronClassifierObservations ) = length (observations. y)
153
- Base. getindex (observations, I) = PerceptronClassifierObservations (
152
+ Base. length (observations:: PerceptronClassifierObs ) = length (observations. y)
153
+ Base. getindex (observations, I) = PerceptronClassifierObs (
154
154
(@view observations. X[:, I]),
155
155
(@view observations. y[I]),
156
156
observations. classes,
157
157
)
158
158
159
159
LearnAPI. target (
160
160
algorithm:: PerceptronClassifier ,
161
- observations:: PerceptronClassifierObservations ,
161
+ observations:: PerceptronClassifierObs ,
162
162
) = observations. y
163
163
164
164
LearnAPI. features (
165
165
algorithm:: PerceptronClassifier ,
166
- observations:: PerceptronClassifierObservations ,
166
+ observations:: PerceptronClassifierObs ,
167
167
) = observations. X
168
168
169
169
# Note that data consumed by `predict` needs no pre-processing, so no need to overload
@@ -186,7 +186,7 @@ LearnAPI.algorithm(model::PerceptronClassifierFitted) = model.algorithm
186
186
# `fit` for pre-processed data (output of `obs(algorithm, data)`):
187
187
function LearnAPI. fit (
188
188
algorithm:: PerceptronClassifier ,
189
- observations:: PerceptronClassifierObservations ;
189
+ observations:: PerceptronClassifierObs ;
190
190
verbosity= 1 ,
191
191
)
192
192
@@ -221,7 +221,7 @@ LearnAPI.fit(algorithm::PerceptronClassifier, data; kwargs...) =
221
221
# see the `PerceptronClassifier` docstring for `update_observations` logic.
222
222
function LearnAPI. update_observations (
223
223
model:: PerceptronClassifierFitted ,
224
- observations_new:: PerceptronClassifierObservations ;
224
+ observations_new:: PerceptronClassifierObs ;
225
225
verbosity= 1 ,
226
226
replacements... ,
227
227
)
@@ -253,7 +253,7 @@ LearnAPI.update_observations(model::PerceptronClassifierFitted, data; kwargs...)
253
253
# see the `PerceptronClassifier` docstring for `update` logic.
254
254
function LearnAPI. update (
255
255
model:: PerceptronClassifierFitted ,
256
- observations:: PerceptronClassifierObservations ;
256
+ observations:: PerceptronClassifierObs ;
257
257
verbosity= 1 ,
258
258
replacements... ,
259
259
)
0 commit comments