Skip to content

Commit b88c7fb

Browse files
committed
fix: ActiveLearner.__init__ added
1 parent 0d3ffae commit b88c7fb

File tree

1 file changed

+46
-7
lines changed

1 file changed (summary repeated)

+46

-7

lines changed

modAL/models.py

Lines changed: 46 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
11
"""
2-
==========================================
32
Core models for active learning algorithms
4-
==========================================
3+
------------------------------------------
54
"""
65

76
import abc
@@ -49,7 +48,7 @@ class BaseLearner(ABC, BaseEstimator):
4948
"""
5049
def __init__(self,
5150
estimator: BaseEstimator,
52-
query_strategy: Callable = uncertainty_sampling,
51+
query_strategy: Callable,
5352
X_training: Optional[modALinput] = None,
5453
y_training: Optional[modALinput] = None,
5554
bootstrap_init: bool = False,
@@ -223,6 +222,23 @@ class ActiveLearner(BaseLearner):
223222
"""
224223
This class is an abstract model of a general active learning algorithm.
225224
225+
Args:
226+
estimator: The estimator to be used in the active learning loop.
227+
query_strategy: Function providing the query strategy for the active learning loop,
228+
for instance, modAL.uncertainty.uncertainty_sampling.
229+
X_training: Initial training samples, if available.
230+
y_training: Initial training labels corresponding to initial training samples.
231+
bootstrap_init: If initial training data is available, bootstrapping can be done during the first training.
232+
Useful when building Committee models with bagging.
233+
**fit_kwargs: keyword arguments.
234+
235+
Attributes:
236+
estimator: The estimator to be used in the active learning loop.
237+
query_strategy: Function providing the query strategy for the active learning loop.
238+
X_training: If the model hasn't been fitted yet it is None, otherwise it contains the samples
239+
which the model has been trained on.
240+
y_training: The labels corresponding to X_training.
241+
226242
Examples:
227243
228244
>>> from sklearn.datasets import load_iris
@@ -249,9 +265,19 @@ class ActiveLearner(BaseLearner):
249265
... X=iris['data'][query_idx].reshape(1, -1),
250266
... y=iris['target'][query_idx].reshape(1, )
251267
... )
252-
253268
"""
254269

270+
def __init__(self,
271+
estimator: BaseEstimator,
272+
query_strategy: Callable = uncertainty_sampling,
273+
X_training: Optional[modALinput] = None,
274+
y_training: Optional[modALinput] = None,
275+
bootstrap_init: bool = False,
276+
**fit_kwargs
277+
) -> None:
278+
super().__init__(estimator, query_strategy,
279+
X_training, y_training, bootstrap_init, **fit_kwargs)
280+
255281
def teach(self, X: modALinput, y: modALinput, bootstrap: bool = False, only_new: bool = False, **fit_kwargs) -> None:
256282
"""
257283
Adds X and y to the known training data and retrains the predictor with the augmented dataset.
@@ -277,7 +303,23 @@ class BayesianOptimizer(BaseLearner):
277303
"""
278304
This class is an abstract model of a Bayesian optimizer algorithm.
279305
306+
Args:
307+
estimator: The estimator to be used in the Bayesian optimization. (For instance, a
308+
GaussianProcessRegressor.)
309+
query_strategy: Function providing the query strategy for Bayesian optimization,
310+
for instance, modAL.acquisitions.max_EI.
311+
X_training: Initial training samples, if available.
312+
y_training: Initial training labels corresponding to initial training samples.
313+
bootstrap_init: If initial training data is available, bootstrapping can be done during the first training.
314+
Useful when building Committee models with bagging.
315+
**fit_kwargs: keyword arguments.
316+
280317
Attributes:
318+
estimator: The estimator to be used in the Bayesian optimization.
319+
query_strategy: Function providing the query strategy for Bayesian optimization.
320+
X_training: If the model hasn't been fitted yet it is None, otherwise it contains the samples
321+
which the model has been trained on.
322+
y_training: The labels corresponding to X_training.
281323
X_max: argmax of the function so far.
282324
y_max: Max of the function so far.
283325
@@ -322,7 +364,6 @@ class BayesianOptimizer(BaseLearner):
322364
... # query
323365
... query_idx, query_inst = optimizer.query(X)
324366
... optimizer.teach(X[query_idx].reshape(1, -1), y[query_idx].reshape(1, -1))
325-
326367
"""
327368
def __init__(self,
328369
estimator: BaseEstimator,
@@ -566,7 +607,6 @@ class Committee(BaseCommittee):
566607
... X=iris['data'][query_idx].reshape(1, -1),
567608
... y=iris['target'][query_idx].reshape(1, )
568609
... )
569-
570610
"""
571611
def __init__(self, learner_list: List[ActiveLearner], query_strategy: Callable = vote_entropy_sampling) -> None:
572612
super().__init__(learner_list, query_strategy)
@@ -750,7 +790,6 @@ class CommitteeRegressor(BaseCommittee):
750790
>>> for idx in range(n_queries):
751791
... query_idx, query_instance = committee.query(X.reshape(-1, 1))
752792
... committee.teach(X[query_idx].reshape(-1, 1), y[query_idx].reshape(-1, 1))
753-
754793
"""
755794
def __init__(self, learner_list: List[ActiveLearner], query_strategy: Callable = max_std_sampling) -> None:
756795
super().__init__(learner_list, query_strategy)

0 commit comments

Comments (0)