@@ -212,20 +212,20 @@ def teach(self, *args, **kwargs) -> None:
212212class BaseCommittee (ABC , BaseEstimator ):
213213 """
214214 Base class for query-by-committee setup.
215-
216215 Args:
217216 learner_list: List of ActiveLearner objects to form committee.
218217 query_strategy: Function to query labels.
219218 on_transformed: Whether to transform samples with the pipeline defined by each learner's estimator
220219 when applying the query strategy.
221220 """
222-
def __init__(self, learner_list: List[BaseLearner], query_strategy: Callable, on_transformed: bool = False) -> None:
    """
    Store the committee members and the query configuration.

    Args:
        learner_list: List of learner objects forming the committee. Must be a list
            (list subclasses are accepted).
        query_strategy: Function to query labels.
        on_transformed: Whether to transform samples with the pipeline defined by each learner's estimator
            when applying the query strategy.

    Raises:
        AssertionError: If learner_list is not a list.
    """
    # isinstance instead of an exact type comparison, so list subclasses are accepted too.
    # The assert and its message are kept so callers still observe an AssertionError.
    assert isinstance(learner_list, list), 'learners must be supplied in a list'

    self.learner_list = learner_list
    self.query_strategy = query_strategy
    self.on_transformed = on_transformed
    # TODO: update training data when using fit() and teach() methods
    self.X_training = None
229229
230230 def __iter__ (self ) -> Iterator [BaseLearner ]:
231231 for learner in self .learner_list :
@@ -234,10 +234,33 @@ def __iter__(self) -> Iterator[BaseLearner]:
def __len__(self) -> int:
    """Return the number of learners held in ``self.learner_list``."""
    committee_size = len(self.learner_list)
    return committee_size
236236
def _add_training_data(self, X: modALinput, y: modALinput) -> None:
    """
    Forward new labelled samples to every learner in the committee.

    Each learner's own ``_add_training_data`` is called with the same X and y, in the order
    the learners appear in ``self.learner_list``. No fitting is triggered by this method itself.

    Args:
        X: The new samples for which the labels are supplied by the expert.
        y: Labels corresponding to the new instances in X.

    Note:
        If the learners have been fitted, the features in X have to agree with the training
        samples which the classifier has seen.
    """
    for member in self.learner_list:
        member._add_training_data(X, y)
249+
def _fit_to_known(self, bootstrap: bool = False, **fit_kwargs) -> None:
    """
    Ask every learner to fit on the training data and labels it has been given so far.

    The call is delegated to each learner's own ``_fit_to_known``, in the order the learners
    appear in ``self.learner_list``.

    Args:
        bootstrap: If True, each estimator is trained on a bootstrapped dataset. Useful when
            using bagging to build the ensemble.
        **fit_kwargs: Keyword arguments to be passed to the fit method of the predictor.
    """
    for member in self.learner_list:
        member._fit_to_known(bootstrap=bootstrap, **fit_kwargs)
260+
237261 def _fit_on_new (self , X : modALinput , y : modALinput , bootstrap : bool = False , ** fit_kwargs ) -> None :
238262 """
239263 Fits all learners to the given data and labels.
240-
241264 Args:
242265 X: The new samples for which the labels are supplied by the expert.
243266 y: Labels corresponding to the new instances in X.
@@ -247,16 +270,27 @@ def _fit_on_new(self, X: modALinput, y: modALinput, bootstrap: bool = False, **f
247270 for learner in self .learner_list :
248271 learner ._fit_on_new (X , y , bootstrap = bootstrap , ** fit_kwargs )
249272
250- @abc .abstractmethod
251- def predict (self , X : modALinput ) -> Any :
252- pass
def fit(self, X: modALinput, y: modALinput, **fit_kwargs) -> 'BaseCommittee':
    """
    Fit every learner in the committee on X and y.

    X, y and the keyword arguments are handed unchanged to each learner's ``fit`` method.
    No bootstrap flag is set here, so any resampling depends on the learner's ``fit`` and on
    what the caller supplies in ``fit_kwargs``.

    Note:
        Per the learners' documented contract (not visible here, confirm against the learner
        class), calling ``fit`` makes a learner forget the data it has seen so far and replaces
        it with X. To refit each learner on a bootstrap of the data it already knows, use
        ``.rebag()`` instead.

    Args:
        X: The samples to be fitted on.
        y: The corresponding labels.
        **fit_kwargs: Keyword arguments to be passed to the fit method of the predictor.

    Returns:
        The committee itself.
    """
    for member in self.learner_list:
        member.fit(X, y, **fit_kwargs)
    return self
253288
254289 def transform_without_estimating (self , X : modALinput ) -> Union [np .ndarray , sp .csr_matrix ]:
255290 """
256291 Transforms the data as supplied to each learner's estimator and concatenates transformations.
257292 Args:
258293 X: dataset to be transformed
259-
260294 Returns:
261295 Transformed data set
262296 """
@@ -298,32 +332,38 @@ def query(self, X_pool, return_metrics: bool = False, *query_args, **query_kwarg
298332 else :
299333 return query_result , retrieve_rows (X_pool , query_result )
300334
301- def _set_classes (self ) :
def rebag(self, **fit_kwargs) -> None:
    """
    Refit every learner on a dataset bootstrapped from its own training instances.

    This delegates to ``self._fit_to_known`` with ``bootstrap=True``. Contrary to ``.bag()``,
    the bootstrap for each learner is based on that learner's own examples.

    Todo:
        Where is .bag()?

    Args:
        **fit_kwargs: Keyword arguments to be passed to the fit method of the predictor.
    """
    self._fit_to_known(bootstrap=True, **fit_kwargs)

322345
def teach(self, X: modALinput, y: modALinput, bootstrap: bool = False, only_new: bool = False, **fit_kwargs) -> None:
    """
    Add X and y to each learner's known training data, then retrain the learners.

    The new samples are always registered first via ``_add_training_data``. Afterwards the
    learners are refitted either on everything they know (default) or on X and y alone.

    Args:
        X: The new samples for which the labels are supplied by the expert.
        y: Labels corresponding to the new instances in X.
        bootstrap: If True, trains each learner on a bootstrapped set. Useful when building the
            ensemble by bagging.
        only_new: If True, the model is retrained using only X and y, ignoring the previously
            provided examples.
        **fit_kwargs: Keyword arguments to be passed to the fit method of the predictor.
    """
    self._add_training_data(X, y)

    # Retrain on the freshly supplied samples only, when explicitly requested.
    if only_new:
        self._fit_on_new(X, y, bootstrap=bootstrap, **fit_kwargs)
        return

    # Default path: retrain on all data known to each learner.
    self._fit_to_known(bootstrap=bootstrap, **fit_kwargs)
361+
@abc.abstractmethod
def predict(self, X: modALinput) -> Any:
    """
    Abstract hook: produce the committee's prediction for X.

    Concrete committees must override this method. The base implementation does nothing
    and returns None.

    Args:
        X: The samples to predict on.
    """
    return None
324365
@abc.abstractmethod
def vote(self, X: modALinput) -> Any:  # TODO: clarify typing
    """
    Abstract hook: collect the committee members' votes for X.

    Concrete committees must override this method. The base implementation does nothing
    and returns None.

    Args:
        X: The samples to vote on.
    """
    return None
328369
329-
0 commit comments