1+ """
2+ Base classes for active learning algorithms
3+ ------------------------------------------
4+ """
5+
6+ import abc
7+ import sys
8+ from typing import Union , Callable , Optional , Tuple , List , Iterator , Any
9+
10+ import numpy as np
11+ from sklearn .base import BaseEstimator
12+ from sklearn .utils import check_X_y
13+
14+ from modAL .utils .data import data_vstack , modALinput
15+
16+
17+ if sys .version_info >= (3 , 4 ):
18+ ABC = abc .ABC
19+ else :
20+ ABC = abc .ABCMeta ('ABC' , (), {})
21+
22+
23+ class BaseLearner (ABC , BaseEstimator ):
24+ """
25+ Core abstraction in modAL.
26+
27+ Args:
28+ estimator: The estimator to be used in the active learning loop.
29+ query_strategy: Function providing the query strategy for the active learning loop,
30+ for instance, modAL.uncertainty.uncertainty_sampling.
31+ X_training: Initial training samples, if available.
32+ y_training: Initial training labels corresponding to initial training samples.
33+ bootstrap_init: If initial training data is available, bootstrapping can be done during the first training.
34+ Useful when building Committee models with bagging.
35+ **fit_kwargs: keyword arguments.
36+
37+ Attributes:
38+ estimator: The estimator to be used in the active learning loop.
39+ query_strategy: Function providing the query strategy for the active learning loop.
40+ X_training: If the model hasn't been fitted yet it is None, otherwise it contains the samples
41+ which the model has been trained on.
42+ y_training: The labels corresponding to X_training.
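
    Examples:
        An illustrative sketch only: it assumes the concrete ActiveLearner subclass from
        modAL.models and a scikit-learn classifier, neither of which is defined in this
        module.

        >>> import numpy as np
        >>> from sklearn.ensemble import RandomForestClassifier
        >>> from modAL.models import ActiveLearner
        >>> from modAL.uncertainty import uncertainty_sampling
        >>> # a small, randomly generated initial training set
        >>> X_initial = np.random.rand(10, 2)
        >>> y_initial = np.random.randint(0, 2, size=10)
        >>> learner = ActiveLearner(
        ...     estimator=RandomForestClassifier(),
        ...     query_strategy=uncertainty_sampling,
        ...     X_training=X_initial, y_training=y_initial
        ... )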
43+ """
44+ def __init__ (self ,
45+ estimator : BaseEstimator ,
46+ query_strategy : Callable ,
47+ X_training : Optional [modALinput ] = None ,
48+ y_training : Optional [modALinput ] = None ,
49+ bootstrap_init : bool = False ,
50+ ** fit_kwargs
51+ ) -> None :
52+ assert callable (query_strategy ), 'query_strategy must be callable'
53+
54+ self .estimator = estimator
55+ self .query_strategy = query_strategy
56+
57+ self .X_training = X_training
58+ self .y_training = y_training
59+ if X_training is not None :
60+ self ._fit_to_known (bootstrap = bootstrap_init , ** fit_kwargs )
61+
62+ def _add_training_data (self , X : modALinput , y : modALinput ) -> None :
63+ """
64+ Adds the new data and label to the known data, but does not retrain the model.
65+
66+ Args:
67+ X: The new samples for which the labels are supplied by the expert.
68+ y: Labels corresponding to the new instances in X.
69+
70+ Note:
71+ If the classifier has been fitted, the features in X have to agree with the training samples which the
72+ classifier has seen.
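
        A shape sketch of the note above, assuming plain numpy inputs (for dense arrays the
        stacking step below amounts to concatenation along the first axis):

        >>> import numpy as np
        >>> X_known = np.random.rand(10, 2)   # samples seen so far, 2 features each
        >>> X_new = np.random.rand(3, 2)      # a new batch must also have 2 features
        >>> np.vstack((X_known, X_new)).shape
        (13, 2)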
73+ """
74+ check_X_y (X , y , accept_sparse = True , ensure_2d = False , allow_nd = True , multi_output = True )
75+
76+ if self .X_training is None :
77+ self .X_training = X
78+ self .y_training = y
79+ else :
80+ try :
81+ self .X_training = data_vstack ((self .X_training , X ))
82+ self .y_training = data_vstack ((self .y_training , y ))
83+ except ValueError :
84+ raise ValueError ('the dimensions of the new training data and label must'
85+ 'agree with the training data and labels provided so far' )
86+
87+ def _fit_to_known (self , bootstrap : bool = False , ** fit_kwargs ) -> 'BaseLearner' :
88+ """
89+ Fits self.estimator to the training data and labels provided to it so far.
90+
91+ Args:
92+ bootstrap: If True, the method trains the model on a set bootstrapped from the known training instances.
93+ **fit_kwargs: Keyword arguments to be passed to the fit method of the predictor.
94+
95+ Returns:
96+ self
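
        A minimal sketch of the bootstrap sampling performed below (illustration only; it
        mirrors the np.random.choice call in the method body): indices are drawn with
        replacement from the known training instances, keeping the sample size unchanged.

        >>> import numpy as np
        >>> n_instances = 5
        >>> bootstrap_idx = np.random.choice(range(n_instances), n_instances, replace=True)
        >>> bootstrap_idx.shape
        (5,)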
97+ """
98+ if not bootstrap :
99+ self .estimator .fit (self .X_training , self .y_training , ** fit_kwargs )
100+ else :
101+ n_instances = self .X_training .shape [0 ]
102+ bootstrap_idx = np .random .choice (range (n_instances ), n_instances , replace = True )
103+ self .estimator .fit (self .X_training [bootstrap_idx ], self .y_training [bootstrap_idx ], ** fit_kwargs )
104+
105+ return self
106+
107+ def _fit_on_new (self , X : modALinput , y : modALinput , bootstrap : bool = False , ** fit_kwargs ) -> 'BaseLearner' :
108+ """
109+ Fits self.estimator to the given data and labels.
110+
111+ Args:
112+ X: The new samples for which the labels are supplied by the expert.
113+ y: Labels corresponding to the new instances in X.
114+ bootstrap: If True, the method trains the model on a set bootstrapped from X.
115+ **fit_kwargs: Keyword arguments to be passed to the fit method of the predictor.
116+
117+ Returns:
118+ self
119+ """
120+ check_X_y (X , y , accept_sparse = True , ensure_2d = False , allow_nd = True , multi_output = True )
121+
122+ if not bootstrap :
123+ self .estimator .fit (X , y , ** fit_kwargs )
124+ else :
125+ bootstrap_idx = np .random .choice (range (X .shape [0 ]), X .shape [0 ], replace = True )
126+ self .estimator .fit (X [bootstrap_idx ], y [bootstrap_idx ])
127+
128+ return self
129+
130+ def fit (self , X : modALinput , y : modALinput , bootstrap : bool = False , ** fit_kwargs ) -> 'BaseLearner' :
131+ """
132+ Interface for the fit method of the predictor. Fits the predictor to the supplied data, then stores it
133+ internally for the active learning loop.
134+
135+ Args:
136+ X: The samples to be fitted.
137+ y: The corresponding labels.
138+ bootstrap: If true, trains the estimator on a set bootstrapped from X.
139+ Useful for building Committee models with bagging.
140+ **fit_kwargs: Keyword arguments to be passed to the fit method of the predictor.
141+
142+ Note:
143+ When using scikit-learn estimators, calling this method will make the ActiveLearner forget all training data
144+ it has seen!
145+
146+ Returns:
147+ self
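
        A sketch of the note above (assumes the concrete ActiveLearner subclass from
        modAL.models and scikit-learn's KNeighborsClassifier, neither defined here):

        >>> import numpy as np
        >>> from sklearn.neighbors import KNeighborsClassifier
        >>> from modAL.models import ActiveLearner
        >>> X_old, y_old = np.random.rand(10, 2), np.random.randint(0, 2, size=10)
        >>> X_new, y_new = np.random.rand(10, 2), np.random.randint(0, 2, size=10)
        >>> learner = ActiveLearner(estimator=KNeighborsClassifier(n_neighbors=3),
        ...                         X_training=X_old, y_training=y_old)
        >>> _ = learner.fit(X_new, y_new)
        >>> learner.X_training is X_new   # the previously stored training data is replaced
        True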
148+ """
149+ check_X_y (X , y , accept_sparse = True , ensure_2d = False , allow_nd = True , multi_output = True )
150+ self .X_training , self .y_training = X , y
151+ return self ._fit_to_known (bootstrap = bootstrap , ** fit_kwargs )
152+
153+ def predict (self , X : modALinput , ** predict_kwargs ) -> Any :
154+ """
155+ Estimator predictions for X. Interface with the predict method of the estimator.
156+
157+ Args:
158+ X: The samples to be predicted.
159+ **predict_kwargs: Keyword arguments to be passed to the predict method of the estimator.
160+
161+ Returns:
162+ Estimator predictions for X.
163+ """
164+ return self .estimator .predict (X , ** predict_kwargs )
165+
166+ def predict_proba (self , X : modALinput , ** predict_proba_kwargs ) -> Any :
167+ """
168+ Class probabilities if the predictor is a classifier. Interface with the predict_proba method of the classifier.
169+
170+ Args:
171+ X: The samples for which the class probabilities are to be predicted.
172+ **predict_proba_kwargs: Keyword arguments to be passed to the predict_proba method of the classifier.
173+
174+ Returns:
175+ Class probabilities for X.
176+ """
177+ return self .estimator .predict_proba (X , ** predict_proba_kwargs )
178+
179+ def query (self , * query_args , ** query_kwargs ) -> Union [Tuple , modALinput ]:
180+ """
181+ Finds the n_instances most informative point in the data provided by calling the query_strategy function.
182+
183+ Args:
184+ *query_args: The arguments for the query strategy. For instance, in the case of
185+ :func:`~modAL.uncertainty.uncertainty_sampling`, it is the pool of samples from which the query strategy
186+ should choose instances to request labels.
187+ **query_kwargs: Keyword arguments for the query strategy function.
188+
189+ Returns:
190+ Value of the query_strategy function. Should be the indices of the instances from the pool chosen to be
191+ labelled and the instances themselves. Can be different in other cases, for instance only the instance to be
192+ labelled upon query synthesis.
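
        Illustrative sketch (assumes the concrete ActiveLearner subclass with its default
        uncertainty_sampling strategy and an unlabelled pool; with that strategy the result
        is a tuple of the query indices and the corresponding instances):

        >>> import numpy as np
        >>> from sklearn.neighbors import KNeighborsClassifier
        >>> from modAL.models import ActiveLearner
        >>> X_train, y_train = np.random.rand(10, 2), np.random.randint(0, 2, size=10)
        >>> learner = ActiveLearner(estimator=KNeighborsClassifier(n_neighbors=3),
        ...                         X_training=X_train, y_training=y_train)
        >>> X_pool = np.random.rand(100, 2)
        >>> query_idx, query_instances = learner.query(X_pool)  # indices into X_pool and the rows themselves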
193+ """
194+ query_result = self .query_strategy (self , * query_args , ** query_kwargs )
195+ return query_result
196+
197+ def score (self , X : modALinput , y : modALinput , ** score_kwargs ) -> Any :
198+ """
199+ Interface for the score method of the predictor.
200+
201+ Args:
202+ X: The samples for which prediction accuracy is to be calculated.
203+ y: Ground truth labels for X.
204+ **score_kwargs: Keyword arguments to be passed to the .score() method of the predictor.
205+
206+ Returns:
207+ The score of the predictor.
208+ """
209+ return self .estimator .score (X , y , ** score_kwargs )
210+
211+ @abc .abstractmethod
212+ def teach (self , * args , ** kwargs ) -> None :
213+ pass
214+
215+
216+ class BaseCommittee (ABC , BaseEstimator ):
217+ """
218+ Base class for query-by-committee setup.
219+
220+ Args:
221+ learner_list: List of ActiveLearner objects to form committee.
222+ query_strategy: Function to query labels.
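
    Examples:
        An illustrative sketch only: it assumes the concrete ActiveLearner and Committee
        subclasses from modAL.models and the vote_entropy_sampling strategy from
        modAL.disagreement, none of which are defined in this module.

        >>> import numpy as np
        >>> from sklearn.ensemble import RandomForestClassifier
        >>> from sklearn.neighbors import KNeighborsClassifier
        >>> from modAL.models import ActiveLearner, Committee
        >>> from modAL.disagreement import vote_entropy_sampling
        >>> X, y = np.random.rand(10, 2), np.random.randint(0, 2, size=10)
        >>> committee = Committee(
        ...     learner_list=[
        ...         ActiveLearner(estimator=KNeighborsClassifier(n_neighbors=3),
        ...                       X_training=X, y_training=y),
        ...         ActiveLearner(estimator=RandomForestClassifier(),
        ...                       X_training=X, y_training=y),
        ...     ],
        ...     query_strategy=vote_entropy_sampling
        ... )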
223+ """
224+ def __init__ (self , learner_list : List [BaseLearner ], query_strategy : Callable ) -> None :
225+ assert type (learner_list ) == list , 'learners must be supplied in a list'
226+
227+ self .learner_list = learner_list
228+ self .query_strategy = query_strategy
229+
230+ def __iter__ (self ) -> Iterator [BaseLearner ]:
231+ for learner in self .learner_list :
232+ yield learner
233+
234+ def __len__ (self ) -> int :
235+ return len (self .learner_list )
236+
237+ def _add_training_data (self , X : modALinput , y : modALinput ) -> None :
238+ """
239+ Adds the new data and label to the known data for each learner, but does not retrain the model.
240+
241+ Args:
242+ X: The new samples for which the labels are supplied by the expert.
243+ y: Labels corresponding to the new instances in X.
244+
245+ Note:
246+ If the learners have been fitted, the features in X have to agree with the training samples which the
247+ classifier has seen.
248+ """
249+ for learner in self .learner_list :
250+ learner ._add_training_data (X , y )
251+
252+ def _fit_to_known (self , bootstrap : bool = False , ** fit_kwargs ) -> None :
253+ """
254+ Fits all learners to the training data and labels provided to it so far.
255+
256+ Args:
257+ bootstrap: If True, each estimator is trained on a bootstrapped dataset. Useful when
258+ using bagging to build the ensemble.
259+ **fit_kwargs: Keyword arguments to be passed to the fit method of the predictor.
260+ """
261+ for learner in self .learner_list :
262+ learner ._fit_to_known (bootstrap = bootstrap , ** fit_kwargs )
263+
264+ def _fit_on_new (self , X : modALinput , y : modALinput , bootstrap : bool = False , ** fit_kwargs ) -> None :
265+ """
266+ Fits all learners to the given data and labels.
267+
268+ Args:
269+ X: The new samples for which the labels are supplied by the expert.
270+ y: Labels corresponding to the new instances in X.
271+ bootstrap: If True, the method trains the model on a set bootstrapped from X.
272+ **fit_kwargs: Keyword arguments to be passed to the fit method of the predictor.
273+ """
274+ for learner in self .learner_list :
275+ learner ._fit_on_new (X , y , bootstrap = bootstrap , ** fit_kwargs )
276+
    def fit(self, X: modALinput, y: modALinput, **fit_kwargs) -> 'BaseCommittee':
        """
        Fits every learner to X and y. Calling this method makes every learner forget the data it has seen up until
        this point and replaces it with X! If you would like to perform bootstrapping on each learner using the data
        it has seen, use the method .rebag()!

        Args:
            X: The samples to be fitted on.
            y: The corresponding labels.
            **fit_kwargs: Keyword arguments to be passed to the fit method of the predictor.

        Returns:
            self
        """
        for learner in self.learner_list:
            learner.fit(X, y, **fit_kwargs)

        return self

    def query(self, *query_args, **query_kwargs) -> Union[Tuple, modALinput]:
        """
        Finds the n_instances most informative points in the data provided by calling the query_strategy function.

        Args:
            *query_args: The arguments for the query strategy. For instance, in the case of
                :func:`~modAL.disagreement.max_disagreement_sampling`, it is the pool of samples from which the
                query strategy should choose instances to request labels for.
            **query_kwargs: Keyword arguments for the query strategy function.

        Returns:
            Return value of the query_strategy function. Should be the indices of the instances from the pool chosen
            to be labelled and the instances themselves. Can be different in other cases, for instance only the
            instance to be labelled upon query synthesis.
        """
        query_result = self.query_strategy(self, *query_args, **query_kwargs)
        return query_result

    def rebag(self, **fit_kwargs) -> None:
        """
        Refits every learner with a dataset bootstrapped from its own training instances. In contrast to bagging at
        fit time, the training data for each learner is resampled from the examples that particular learner has
        already seen.

        Args:
            **fit_kwargs: Keyword arguments to be passed to the fit method of the predictor.
        """
        self._fit_to_known(bootstrap=True, **fit_kwargs)

    def teach(self, X: modALinput, y: modALinput, bootstrap: bool = False, only_new: bool = False, **fit_kwargs) -> None:
        """
        Adds X and y to the known training data for each learner and retrains learners with the augmented dataset.

        Args:
            X: The new samples for which the labels are supplied by the expert.
            y: Labels corresponding to the new instances in X.
            bootstrap: If True, trains each learner on a bootstrapped set. Useful when building the ensemble by bagging.
            only_new: If True, the model is retrained using only X and y, ignoring the previously provided examples.
            **fit_kwargs: Keyword arguments to be passed to the fit method of the predictor.
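
        An illustrative sketch (assumes the concrete ActiveLearner and Committee subclasses
        from modAL.models; X_new and y_new stand for a freshly labelled batch):

        >>> import numpy as np
        >>> from sklearn.neighbors import KNeighborsClassifier
        >>> from modAL.models import ActiveLearner, Committee
        >>> X, y = np.random.rand(10, 2), np.random.randint(0, 2, size=10)
        >>> learners = [ActiveLearner(estimator=KNeighborsClassifier(n_neighbors=3),
        ...                           X_training=X, y_training=y) for _ in range(2)]
        >>> committee = Committee(learner_list=learners)
        >>> X_new, y_new = np.random.rand(3, 2), np.random.randint(0, 2, size=3)
        >>> committee.teach(X_new, y_new)                 # refit on everything seen so far
        >>> committee.teach(X_new, y_new, only_new=True)  # refit on the new batch only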
336+ """
337+ self ._add_training_data (X , y )
338+ if not only_new :
339+ self ._fit_to_known (bootstrap = bootstrap , ** fit_kwargs )
340+ else :
341+ self ._fit_on_new (X , y , bootstrap = bootstrap , ** fit_kwargs )
342+
343+ @abc .abstractmethod
344+ def predict (self , X : modALinput ) -> Any :
345+ pass
346+
347+ @abc .abstractmethod
348+ def vote (self , X : modALinput ) -> Any : # TODO: clarify typing
349+ pass