Skip to content

Commit d792e51

Browse files
committed
fix a mistake in implementation of Ensemble
1 parent 083bae9 commit d792e51

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

test/integration/iterative_algorithms.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ LearnAPI.algorithm(model::EnsembleFitted) = model.algorithm
5555
LearnAPI.obs(algorithm::Ensemble, data) = LearnAPI.obs(algorithm.atom, data)
5656
LearnAPI.obs(model::EnsembleFitted, data) = LearnAPI.obs(first(model.models), data)
5757
LearnAPI.target(algorithm::Ensemble, data) = LearnAPI.target(algorithm.atom, data)
58-
LearnAPI.features(algorithm::Ridge, data) = LearnAPI.features(algorithm.atom, data)
58+
LearnAPI.features(algorithm::Ensemble, data) = LearnAPI.features(algorithm.atom, data)
5959

6060
function LearnAPI.fit(algorithm::Ensemble, data; verbosity=1)
6161

@@ -97,10 +97,11 @@ end
9797
# models. Otherwise, update is equivalent to retraining from scratch, with the provided
9898
# hyperparameter updates.
9999
function LearnAPI.update(model::EnsembleFitted, data; verbosity=1, replacements...)
100-
:n in keys(replacements) || return fit(model, data)
101-
102100
algorithm_old = LearnAPI.algorithm(model)
103101
algorithm = LearnAPI.clone(algorithm_old; replacements...)
102+
103+
:n in keys(replacements) || return fit(algorithm, data)
104+
104105
n = algorithm.n
105106
Δn = n - algorithm_old.n
106107
n < 0 && return fit(model, algorithm)
@@ -156,7 +157,6 @@ LearnAPI.minimize(model::EnsembleFitted) = EnsembleFitted(
156157
:(LearnAPI.target),
157158
:(LearnAPI.update),
158159
:(LearnAPI.predict),
159-
:(LearnAPI.feature_importances),
160160
)
161161
)
162162

@@ -190,16 +190,18 @@ Xtest = Tables.subset(X, test)
190190
@test ŷ4 == predict(model, Xtest)
191191

192192
# add 3 atomic models to the ensemble:
193-
# model = @test_logs(
194-
# (:info, r"Trained 3 additional"),
195-
# update(model, Xtrain, y[train]; n=7),
196-
# )
197193
model = update(model, Xtrain, y[train]; verbosity=0, n=7);
198194
ŷ7 = predict(model, Xtest)
199195

200196
# compare with cold restart:
201197
model = fit(LearnAPI.clone(algorithm; n=7), Xtrain, y[train]; verbosity=0);
202198
@test ŷ7 predict(model, Xtest)
199+
200+
# test cold restart if another hyperparameter is changed:
201+
model2 = update(model, Xtrain, y[train]; atom=Ridge(0.05))
202+
algorithm2 = LearnAPI.clone(LearnAPI.algorithm(model); atom=Ridge(0.05))
203+
@test predict(model, Xtest) predict(model2, Xtest)
204+
203205
end
204206

205207
true

0 commit comments

Comments
 (0)