Skip to content

Commit ff8356b

Browse files
author
Max Keller
committed
Remove DeepCommittee
1 parent 4c8ad5d commit ff8356b

File tree

3 files changed

+94
-300
lines changed

3 files changed

+94
-300
lines changed

modAL/models/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from .learners import ActiveLearner, DeepActiveLearner, BayesianOptimizer, Committee, DeepCommittee, CommitteeRegressor
1+
from .learners import ActiveLearner, DeepActiveLearner, BayesianOptimizer, Committee, CommitteeRegressor
22

33
__all__ = [
44
'ActiveLearner', 'DeepActiveLearner', 'BayesianOptimizer',
5-
'Committee', 'DeepCommittee', 'CommitteeRegressor'
5+
'Committee', 'CommitteeRegressor'
66
]

modAL/models/base.py

Lines changed: 67 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -212,20 +212,20 @@ def teach(self, *args, **kwargs) -> None:
212212
class BaseCommittee(ABC, BaseEstimator):
213213
"""
214214
Base class for query-by-committee setup.
215-
216215
Args:
217216
learner_list: List of ActiveLearner objects to form committee.
218217
query_strategy: Function to query labels.
219218
on_transformed: Whether to transform samples with the pipeline defined by each learner's estimator
220219
when applying the query strategy.
221220
"""
222-
223221
def __init__(self, learner_list: List[BaseLearner], query_strategy: Callable, on_transformed: bool = False) -> None:
224222
assert type(learner_list) == list, 'learners must be supplied in a list'
225223

226224
self.learner_list = learner_list
227225
self.query_strategy = query_strategy
228226
self.on_transformed = on_transformed
227+
# TODO: update training data when using fit() and teach() methods
228+
self.X_training = None
229229

230230
def __iter__(self) -> Iterator[BaseLearner]:
231231
for learner in self.learner_list:
@@ -234,10 +234,33 @@ def __iter__(self) -> Iterator[BaseLearner]:
234234
def __len__(self) -> int:
235235
return len(self.learner_list)
236236

237+
def _add_training_data(self, X: modALinput, y: modALinput) -> None:
238+
"""
239+
Adds the new data and label to the known data for each learner, but does not retrain the model.
240+
Args:
241+
X: The new samples for which the labels are supplied by the expert.
242+
y: Labels corresponding to the new instances in X.
243+
Note:
244+
If the learners have been fitted, the features in X have to agree with the training samples which the
245+
classifier has seen.
246+
"""
247+
for learner in self.learner_list:
248+
learner._add_training_data(X, y)
249+
250+
def _fit_to_known(self, bootstrap: bool = False, **fit_kwargs) -> None:
251+
"""
252+
Fits all learners to the training data and labels provided to it so far.
253+
Args:
254+
bootstrap: If True, each estimator is trained on a bootstrapped dataset. Useful when
255+
using bagging to build the ensemble.
256+
**fit_kwargs: Keyword arguments to be passed to the fit method of the predictor.
257+
"""
258+
for learner in self.learner_list:
259+
learner._fit_to_known(bootstrap=bootstrap, **fit_kwargs)
260+
237261
def _fit_on_new(self, X: modALinput, y: modALinput, bootstrap: bool = False, **fit_kwargs) -> None:
238262
"""
239263
Fits all learners to the given data and labels.
240-
241264
Args:
242265
X: The new samples for which the labels are supplied by the expert.
243266
y: Labels corresponding to the new instances in X.
@@ -247,16 +270,27 @@ def _fit_on_new(self, X: modALinput, y: modALinput, bootstrap: bool = False, **f
247270
for learner in self.learner_list:
248271
learner._fit_on_new(X, y, bootstrap=bootstrap, **fit_kwargs)
249272

250-
@abc.abstractmethod
251-
def predict(self, X: modALinput) -> Any:
252-
pass
273+
def fit(self, X: modALinput, y: modALinput, **fit_kwargs) -> 'BaseCommittee':
274+
"""
275+
Fits every learner to a subset sampled with replacement from X. Calling this method makes the learner forget the
276+
data it has seen up until this point and replaces it with X! If you would like to perform bootstrapping on each
277+
learner using the data it has seen, use the method .rebag()!
278+
Calling this method makes the learner forget the data it has seen up until this point and replaces it with X!
279+
Args:
280+
X: The samples to be fitted on.
281+
y: The corresponding labels.
282+
**fit_kwargs: Keyword arguments to be passed to the fit method of the predictor.
283+
"""
284+
for learner in self.learner_list:
285+
learner.fit(X, y, **fit_kwargs)
286+
287+
return self
253288

254289
def transform_without_estimating(self, X: modALinput) -> Union[np.ndarray, sp.csr_matrix]:
255290
"""
256291
Transforms the data as supplied to each learner's estimator and concatenates transformations.
257292
Args:
258293
X: dataset to be transformed
259-
260294
Returns:
261295
Transformed data set
262296
"""
@@ -298,32 +332,38 @@ def query(self, X_pool, return_metrics: bool = False, *query_args, **query_kwarg
298332
else:
299333
return query_result, retrieve_rows(X_pool, query_result)
300334

301-
def _set_classes(self):
335+
def rebag(self, **fit_kwargs) -> None:
302336
"""
303-
Checks the known class labels by each learner, merges the labels and returns a mapping which maps the learner's
304-
classes to the complete label list.
337+
Refits every learner with a dataset bootstrapped from its training instances. Contrary to .bag(), it bootstraps
338+
the training data for each learner based on its own examples.
339+
Todo:
340+
Where is .bag()?
341+
Args:
342+
**fit_kwargs: Keyword arguments to be passed to the fit method of the predictor.
305343
"""
306-
# assemble the list of known classes from each learner
307-
try:
308-
# if estimators are fitted
309-
known_classes = tuple(
310-
learner.estimator.classes_ for learner in self.learner_list)
311-
except AttributeError:
312-
# handle unfitted estimators
313-
self.classes_ = None
314-
self.n_classes_ = 0
315-
return
316-
317-
self.classes_ = np.unique(
318-
np.concatenate(known_classes, axis=0),
319-
axis=0
320-
)
321-
self.n_classes_ = len(self.classes_)
344+
self._fit_to_known(bootstrap=True, **fit_kwargs)
322345

346+
def teach(self, X: modALinput, y: modALinput, bootstrap: bool = False, only_new: bool = False, **fit_kwargs) -> None:
347+
"""
348+
Adds X and y to the known training data for each learner and retrains learners with the augmented dataset.
349+
Args:
350+
X: The new samples for which the labels are supplied by the expert.
351+
y: Labels corresponding to the new instances in X.
352+
bootstrap: If True, trains each learner on a bootstrapped set. Useful when building the ensemble by bagging.
353+
only_new: If True, the model is retrained using only X and y, ignoring the previously provided examples.
354+
**fit_kwargs: Keyword arguments to be passed to the fit method of the predictor.
355+
"""
356+
self._add_training_data(X, y)
357+
if not only_new:
358+
self._fit_to_known(bootstrap=bootstrap, **fit_kwargs)
359+
else:
360+
self._fit_on_new(X, y, bootstrap=bootstrap, **fit_kwargs)
361+
362+
@abc.abstractmethod
363+
def predict(self, X: modALinput) -> Any:
323364
pass
324365

325366
@abc.abstractmethod
326367
def vote(self, X: modALinput) -> Any: # TODO: clarify typing
327368
pass
328369

329-

0 commit comments

Comments
 (0)