Skip to content

Commit ff7a52f

Browse files
committed
Merge branch 'OskarLiew-fix/fit-empty-committee' into dev
2 parents c95fb1d + e5bddad commit ff7a52f

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

modAL/models/learners.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,22 @@ def _set_classes(self):
312312
def _add_training_data(self, X: modALinput, y: modALinput):
313313
super()._add_training_data(X, y)
314314

315+
def fit(self, X: modALinput, y: modALinput, **fit_kwargs) -> 'BaseCommittee':
316+
"""
317+
Fits every learner to a subset sampled with replacement from X. Calling this method makes the learner forget the
318+
data it has seen up until this point and replaces it with X! If you would like to perform bootstrapping on each
319+
learner using the data it has seen, use the method .rebag()!
320+
321+
Calling this method makes the learner forget the data it has seen up until this point and replaces it with X!
322+
323+
Args:
324+
X: The samples to be fitted on.
325+
y: The corresponding labels.
326+
**fit_kwargs: Keyword arguments to be passed to the fit method of the predictor.
327+
"""
328+
super().fit(X, y, **fit_kwargs)
329+
self._set_classes()
330+
315331
def teach(self, X: modALinput, y: modALinput, bootstrap: bool = False, only_new: bool = False, **fit_kwargs) -> None:
316332
"""
317333
Adds X and y to the known training data for each learner and retrains learners with the augmented dataset.
@@ -323,7 +339,6 @@ def teach(self, X: modALinput, y: modALinput, bootstrap: bool = False, only_new:
323339
only_new: If True, the model is retrained using only X and y, ignoring the previously provided examples.
324340
**fit_kwargs: Keyword arguments to be passed to the fit method of the predictor.
325341
"""
326-
327342
super().teach(X, y, bootstrap=bootstrap, only_new=only_new, **fit_kwargs)
328343
self._set_classes()
329344

0 commit comments

Comments
 (0)