Skip to content

Commit 6df2fc1

Browse files
committed
refactor: abstract and regular classes are separated
1 parent b88c7fb commit 6df2fc1

File tree

4 files changed

+360
-351
lines changed

4 files changed

+360
-351
lines changed

modAL/batch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from sklearn.metrics.pairwise import pairwise_distances, pairwise_distances_argmin_min
1010

1111
from modAL.utils.data import data_vstack, modALinput
12-
from modAL.models import BaseCommittee, BaseLearner
12+
from modAL.models.base import BaseCommittee, BaseLearner
1313
from modAL.uncertainty import classifier_uncertainty
1414

1515

modAL/models/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from .learners import ActiveLearner, BayesianOptimizer, Committee, CommitteeRegressor
2+
3+
__all__ = [
4+
'ActiveLearner', 'BayesianOptimizer',
5+
'Committee', 'CommitteeRegressor'
6+
]

modAL/models/base.py

Lines changed: 349 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,349 @@
1+
"""
2+
Base classes for active learning algorithms
3+
------------------------------------------
4+
"""
5+
6+
import abc
7+
import sys
8+
from typing import Union, Callable, Optional, Tuple, List, Iterator, Any
9+
10+
import numpy as np
11+
from sklearn.base import BaseEstimator
12+
from sklearn.utils import check_X_y
13+
14+
from modAL.utils.data import data_vstack, modALinput
15+
16+
17+
if sys.version_info >= (3, 4):
18+
ABC = abc.ABC
19+
else:
20+
ABC = abc.ABCMeta('ABC', (), {})
21+
22+
23+
class BaseLearner(ABC, BaseEstimator):
24+
"""
25+
Core abstraction in modAL.
26+
27+
Args:
28+
estimator: The estimator to be used in the active learning loop.
29+
query_strategy: Function providing the query strategy for the active learning loop,
30+
for instance, modAL.uncertainty.uncertainty_sampling.
31+
X_training: Initial training samples, if available.
32+
y_training: Initial training labels corresponding to initial training samples.
33+
bootstrap_init: If initial training data is available, bootstrapping can be done during the first training.
34+
Useful when building Committee models with bagging.
35+
**fit_kwargs: keyword arguments.
36+
37+
Attributes:
38+
estimator: The estimator to be used in the active learning loop.
39+
query_strategy: Function providing the query strategy for the active learning loop.
40+
X_training: If the model hasn't been fitted yet it is None, otherwise it contains the samples
41+
which the model has been trained on.
42+
y_training: The labels corresponding to X_training.
43+
"""
44+
def __init__(self,
45+
estimator: BaseEstimator,
46+
query_strategy: Callable,
47+
X_training: Optional[modALinput] = None,
48+
y_training: Optional[modALinput] = None,
49+
bootstrap_init: bool = False,
50+
**fit_kwargs
51+
) -> None:
52+
assert callable(query_strategy), 'query_strategy must be callable'
53+
54+
self.estimator = estimator
55+
self.query_strategy = query_strategy
56+
57+
self.X_training = X_training
58+
self.y_training = y_training
59+
if X_training is not None:
60+
self._fit_to_known(bootstrap=bootstrap_init, **fit_kwargs)
61+
62+
def _add_training_data(self, X: modALinput, y: modALinput) -> None:
63+
"""
64+
Adds the new data and label to the known data, but does not retrain the model.
65+
66+
Args:
67+
X: The new samples for which the labels are supplied by the expert.
68+
y: Labels corresponding to the new instances in X.
69+
70+
Note:
71+
If the classifier has been fitted, the features in X have to agree with the training samples which the
72+
classifier has seen.
73+
"""
74+
check_X_y(X, y, accept_sparse=True, ensure_2d=False, allow_nd=True, multi_output=True)
75+
76+
if self.X_training is None:
77+
self.X_training = X
78+
self.y_training = y
79+
else:
80+
try:
81+
self.X_training = data_vstack((self.X_training, X))
82+
self.y_training = data_vstack((self.y_training, y))
83+
except ValueError:
84+
raise ValueError('the dimensions of the new training data and label must'
85+
'agree with the training data and labels provided so far')
86+
87+
def _fit_to_known(self, bootstrap: bool = False, **fit_kwargs) -> 'BaseLearner':
88+
"""
89+
Fits self.estimator to the training data and labels provided to it so far.
90+
91+
Args:
92+
bootstrap: If True, the method trains the model on a set bootstrapped from the known training instances.
93+
**fit_kwargs: Keyword arguments to be passed to the fit method of the predictor.
94+
95+
Returns:
96+
self
97+
"""
98+
if not bootstrap:
99+
self.estimator.fit(self.X_training, self.y_training, **fit_kwargs)
100+
else:
101+
n_instances = self.X_training.shape[0]
102+
bootstrap_idx = np.random.choice(range(n_instances), n_instances, replace=True)
103+
self.estimator.fit(self.X_training[bootstrap_idx], self.y_training[bootstrap_idx], **fit_kwargs)
104+
105+
return self
106+
107+
def _fit_on_new(self, X: modALinput, y: modALinput, bootstrap: bool = False, **fit_kwargs) -> 'BaseLearner':
108+
"""
109+
Fits self.estimator to the given data and labels.
110+
111+
Args:
112+
X: The new samples for which the labels are supplied by the expert.
113+
y: Labels corresponding to the new instances in X.
114+
bootstrap: If True, the method trains the model on a set bootstrapped from X.
115+
**fit_kwargs: Keyword arguments to be passed to the fit method of the predictor.
116+
117+
Returns:
118+
self
119+
"""
120+
check_X_y(X, y, accept_sparse=True, ensure_2d=False, allow_nd=True, multi_output=True)
121+
122+
if not bootstrap:
123+
self.estimator.fit(X, y, **fit_kwargs)
124+
else:
125+
bootstrap_idx = np.random.choice(range(X.shape[0]), X.shape[0], replace=True)
126+
self.estimator.fit(X[bootstrap_idx], y[bootstrap_idx])
127+
128+
return self
129+
130+
def fit(self, X: modALinput, y: modALinput, bootstrap: bool = False, **fit_kwargs) -> 'BaseLearner':
131+
"""
132+
Interface for the fit method of the predictor. Fits the predictor to the supplied data, then stores it
133+
internally for the active learning loop.
134+
135+
Args:
136+
X: The samples to be fitted.
137+
y: The corresponding labels.
138+
bootstrap: If true, trains the estimator on a set bootstrapped from X.
139+
Useful for building Committee models with bagging.
140+
**fit_kwargs: Keyword arguments to be passed to the fit method of the predictor.
141+
142+
Note:
143+
When using scikit-learn estimators, calling this method will make the ActiveLearner forget all training data
144+
it has seen!
145+
146+
Returns:
147+
self
148+
"""
149+
check_X_y(X, y, accept_sparse=True, ensure_2d=False, allow_nd=True, multi_output=True)
150+
self.X_training, self.y_training = X, y
151+
return self._fit_to_known(bootstrap=bootstrap, **fit_kwargs)
152+
153+
def predict(self, X: modALinput, **predict_kwargs) -> Any:
154+
"""
155+
Estimator predictions for X. Interface with the predict method of the estimator.
156+
157+
Args:
158+
X: The samples to be predicted.
159+
**predict_kwargs: Keyword arguments to be passed to the predict method of the estimator.
160+
161+
Returns:
162+
Estimator predictions for X.
163+
"""
164+
return self.estimator.predict(X, **predict_kwargs)
165+
166+
def predict_proba(self, X: modALinput, **predict_proba_kwargs) -> Any:
167+
"""
168+
Class probabilities if the predictor is a classifier. Interface with the predict_proba method of the classifier.
169+
170+
Args:
171+
X: The samples for which the class probabilities are to be predicted.
172+
**predict_proba_kwargs: Keyword arguments to be passed to the predict_proba method of the classifier.
173+
174+
Returns:
175+
Class probabilities for X.
176+
"""
177+
return self.estimator.predict_proba(X, **predict_proba_kwargs)
178+
179+
def query(self, *query_args, **query_kwargs) -> Union[Tuple, modALinput]:
180+
"""
181+
Finds the n_instances most informative point in the data provided by calling the query_strategy function.
182+
183+
Args:
184+
*query_args: The arguments for the query strategy. For instance, in the case of
185+
:func:`~modAL.uncertainty.uncertainty_sampling`, it is the pool of samples from which the query strategy
186+
should choose instances to request labels.
187+
**query_kwargs: Keyword arguments for the query strategy function.
188+
189+
Returns:
190+
Value of the query_strategy function. Should be the indices of the instances from the pool chosen to be
191+
labelled and the instances themselves. Can be different in other cases, for instance only the instance to be
192+
labelled upon query synthesis.
193+
"""
194+
query_result = self.query_strategy(self, *query_args, **query_kwargs)
195+
return query_result
196+
197+
def score(self, X: modALinput, y: modALinput, **score_kwargs) -> Any:
198+
"""
199+
Interface for the score method of the predictor.
200+
201+
Args:
202+
X: The samples for which prediction accuracy is to be calculated.
203+
y: Ground truth labels for X.
204+
**score_kwargs: Keyword arguments to be passed to the .score() method of the predictor.
205+
206+
Returns:
207+
The score of the predictor.
208+
"""
209+
return self.estimator.score(X, y, **score_kwargs)
210+
211+
@abc.abstractmethod
212+
def teach(self, *args, **kwargs) -> None:
213+
pass
214+
215+
216+
class BaseCommittee(ABC, BaseEstimator):
217+
"""
218+
Base class for query-by-committee setup.
219+
220+
Args:
221+
learner_list: List of ActiveLearner objects to form committee.
222+
query_strategy: Function to query labels.
223+
"""
224+
def __init__(self, learner_list: List[BaseLearner], query_strategy: Callable) -> None:
225+
assert type(learner_list) == list, 'learners must be supplied in a list'
226+
227+
self.learner_list = learner_list
228+
self.query_strategy = query_strategy
229+
230+
def __iter__(self) -> Iterator[BaseLearner]:
231+
for learner in self.learner_list:
232+
yield learner
233+
234+
def __len__(self) -> int:
235+
return len(self.learner_list)
236+
237+
def _add_training_data(self, X: modALinput, y: modALinput) -> None:
238+
"""
239+
Adds the new data and label to the known data for each learner, but does not retrain the model.
240+
241+
Args:
242+
X: The new samples for which the labels are supplied by the expert.
243+
y: Labels corresponding to the new instances in X.
244+
245+
Note:
246+
If the learners have been fitted, the features in X have to agree with the training samples which the
247+
classifier has seen.
248+
"""
249+
for learner in self.learner_list:
250+
learner._add_training_data(X, y)
251+
252+
def _fit_to_known(self, bootstrap: bool = False, **fit_kwargs) -> None:
253+
"""
254+
Fits all learners to the training data and labels provided to it so far.
255+
256+
Args:
257+
bootstrap: If True, each estimator is trained on a bootstrapped dataset. Useful when
258+
using bagging to build the ensemble.
259+
**fit_kwargs: Keyword arguments to be passed to the fit method of the predictor.
260+
"""
261+
for learner in self.learner_list:
262+
learner._fit_to_known(bootstrap=bootstrap, **fit_kwargs)
263+
264+
def _fit_on_new(self, X: modALinput, y: modALinput, bootstrap: bool = False, **fit_kwargs) -> None:
265+
"""
266+
Fits all learners to the given data and labels.
267+
268+
Args:
269+
X: The new samples for which the labels are supplied by the expert.
270+
y: Labels corresponding to the new instances in X.
271+
bootstrap: If True, the method trains the model on a set bootstrapped from X.
272+
**fit_kwargs: Keyword arguments to be passed to the fit method of the predictor.
273+
"""
274+
for learner in self.learner_list:
275+
learner._fit_on_new(X, y, bootstrap=bootstrap, **fit_kwargs)
276+
277+
def fit(self, X: modALinput, y: modALinput, **fit_kwargs) -> 'BaseCommittee':
278+
"""
279+
Fits every learner to a subset sampled with replacement from X. Calling this method makes the learner forget the
280+
data it has seen up until this point and replaces it with X! If you would like to perform bootstrapping on each
281+
learner using the data it has seen, use the method .rebag()!
282+
283+
Calling this method makes the learner forget the data it has seen up until this point and replaces it with X!
284+
285+
Args:
286+
X: The samples to be fitted on.
287+
y: The corresponding labels.
288+
**fit_kwargs: Keyword arguments to be passed to the fit method of the predictor.
289+
"""
290+
for learner in self.learner_list:
291+
learner.fit(X, y, **fit_kwargs)
292+
293+
return self
294+
295+
def query(self, *query_args, **query_kwargs) -> Union[Tuple, modALinput]:
296+
"""
297+
Finds the n_instances most informative point in the data provided by calling the query_strategy function.
298+
299+
Args:
300+
*query_args: The arguments for the query strategy. For instance, in the case of
301+
:func:`~modAL.disagreement.max_disagreement_sampling`, it is the pool of samples from which the query.
302+
strategy should choose instances to request labels.
303+
**query_kwargs: Keyword arguments for the query strategy function.
304+
305+
Returns:
306+
Return value of the query_strategy function. Should be the indices of the instances from the pool chosen to
307+
be labelled and the instances themselves. Can be different in other cases, for instance only the instance to
308+
be labelled upon query synthesis.
309+
"""
310+
query_result = self.query_strategy(self, *query_args, **query_kwargs)
311+
return query_result
312+
313+
def rebag(self, **fit_kwargs) -> None:
314+
"""
315+
Refits every learner with a dataset bootstrapped from its training instances. Contrary to .bag(), it bootstraps
316+
the training data for each learner based on its own examples.
317+
318+
Todo:
319+
Where is .bag()?
320+
321+
Args:
322+
**fit_kwargs: Keyword arguments to be passed to the fit method of the predictor.
323+
"""
324+
self._fit_to_known(bootstrap=True, **fit_kwargs)
325+
326+
def teach(self, X: modALinput, y: modALinput, bootstrap: bool = False, only_new: bool = False, **fit_kwargs) -> None:
327+
"""
328+
Adds X and y to the known training data for each learner and retrains learners with the augmented dataset.
329+
330+
Args:
331+
X: The new samples for which the labels are supplied by the expert.
332+
y: Labels corresponding to the new instances in X.
333+
bootstrap: If True, trains each learner on a bootstrapped set. Useful when building the ensemble by bagging.
334+
only_new: If True, the model is retrained using only X and y, ignoring the previously provided examples.
335+
**fit_kwargs: Keyword arguments to be passed to the fit method of the predictor.
336+
"""
337+
self._add_training_data(X, y)
338+
if not only_new:
339+
self._fit_to_known(bootstrap=bootstrap, **fit_kwargs)
340+
else:
341+
self._fit_on_new(X, y, bootstrap=bootstrap, **fit_kwargs)
342+
343+
@abc.abstractmethod
344+
def predict(self, X: modALinput) -> Any:
345+
pass
346+
347+
@abc.abstractmethod
348+
def vote(self, X: modALinput) -> Any: # TODO: clarify typing
349+
pass

0 commit comments

Comments
 (0)