Skip to content

Commit 1e9d5e5

Browse files
committed
rename a struct in a test
1 parent 55caed4 commit 1e9d5e5

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

test/integration/gradient_descent.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ PerceptronClassifier(; epochs=50, optimiser=Optimisers.Adam(), rng=Random.defaul
134134
LearnAPI.target(algorithm::PerceptronClassifier, data::Tuple) = last(data)
135135

136136
# For wrapping pre-processed training data (output of `obs(algorithm, data)`):
137-
struct PerceptronClassifierObservations
137+
struct PerceptronClassifierObs
138138
X::Matrix{Float32}
139139
y_hot::BitMatrix # one-hot encoded target
140140
classes # the (ordered) pool of `y`, as `CategoricalValue`s
@@ -145,25 +145,25 @@ function LearnAPI.obs(algorithm::PerceptronClassifier, data::Tuple)
145145
X, y = data
146146
classes = CategoricalDistributions.classes(y)
147147
y_hot = classes .== permutedims(y) # one-hot encoding
148-
return PerceptronClassifierObservations(X, y_hot, classes)
148+
return PerceptronClassifierObs(X, y_hot, classes)
149149
end
150150

151151
# implement `RadomAccess()` interface for output of `obs`:
152-
Base.length(observations::PerceptronClassifierObservations) = length(observations.y)
153-
Base.getindex(observations, I) = PerceptronClassifierObservations(
152+
Base.length(observations::PerceptronClassifierObs) = length(observations.y)
153+
Base.getindex(observations, I) = PerceptronClassifierObs(
154154
(@view observations.X[:, I]),
155155
(@view observations.y[I]),
156156
observations.classes,
157157
)
158158

159159
LearnAPI.target(
160160
algorithm::PerceptronClassifier,
161-
observations::PerceptronClassifierObservations,
161+
observations::PerceptronClassifierObs,
162162
) = observations.y
163163

164164
LearnAPI.features(
165165
algorithm::PerceptronClassifier,
166-
observations::PerceptronClassifierObservations,
166+
observations::PerceptronClassifierObs,
167167
) = observations.X
168168

169169
# Note that data consumed by `predict` needs no pre-processing, so no need to overload
@@ -186,7 +186,7 @@ LearnAPI.algorithm(model::PerceptronClassifierFitted) = model.algorithm
186186
# `fit` for pre-processed data (output of `obs(algorithm, data)`):
187187
function LearnAPI.fit(
188188
algorithm::PerceptronClassifier,
189-
observations::PerceptronClassifierObservations;
189+
observations::PerceptronClassifierObs;
190190
verbosity=1,
191191
)
192192

@@ -221,7 +221,7 @@ LearnAPI.fit(algorithm::PerceptronClassifier, data; kwargs...) =
221221
# see the `PerceptronClassifier` docstring for `update_observations` logic.
222222
function LearnAPI.update_observations(
223223
model::PerceptronClassifierFitted,
224-
observations_new::PerceptronClassifierObservations;
224+
observations_new::PerceptronClassifierObs;
225225
verbosity=1,
226226
replacements...,
227227
)
@@ -253,7 +253,7 @@ LearnAPI.update_observations(model::PerceptronClassifierFitted, data; kwargs...)
253253
# see the `PerceptronClassifier` docstring for `update` logic.
254254
function LearnAPI.update(
255255
model::PerceptronClassifierFitted,
256-
observations::PerceptronClassifierObservations;
256+
observations::PerceptronClassifierObs;
257257
verbosity=1,
258258
replacements...,
259259
)

0 commit comments

Comments
 (0)