Skip to content

Commit 65d9efe

Browse files
committed
modAL.models.ActiveLearner teach method extended, now it is possible to train only on new data
1 parent 66b8f5b commit 65d9efe

File tree

3 files changed

+29
-5
lines changed

3 files changed

+29
-5
lines changed

modAL/cluster.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99

1010

1111
class HierarchicalClustering:
12-
def __init__(self, X, labels):
13-
self.labels = labels
14-
self.cluster = AgglomerativeClustering(n_clusters=2, compute_full_tree=True)
12+
def __init__(self, X, classes, n_batch=1):
13+
self.classes = classes
14+
self.cluster = AgglomerativeClustering(compute_full_tree=True)
1515
self.cluster.fit(X)
1616

1717
def __call__(self, *args, **kwargs):

modAL/models.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ def score(self, X, y, **score_kwargs):
320320
"""
321321
return self.estimator.score(X, y, **score_kwargs)
322322

323-
def teach(self, X, y, bootstrap=False, **fit_kwargs):
323+
def teach(self, X, y, bootstrap=False, only_new=False, **fit_kwargs):
324324
"""
325325
Adds X and y to the known training data and retrains the predictor
326326
with the augmented dataset.
@@ -338,12 +338,20 @@ def teach(self, X, y, bootstrap=False, **fit_kwargs):
338338
If True, training is done on a bootstrapped dataset. Useful for building
339339
Committee models with bagging.
340340
341+
only_new: boolean
342+
If True, the model is retrained using only X and y, ignoring the previously
343+
provided examples. Useful when working with models where the .fit() method
344+
doesn't retrain the model from scratch. (For example, in tensorflow or keras.)
345+
341346
fit_kwargs: keyword arguments
342347
Keyword arguments to be passed to the fit method
343348
of the predictor.
344349
"""
345350
self._add_training_data(X, y)
346-
self._fit_to_known(bootstrap=bootstrap, **fit_kwargs)
351+
if not only_new:
352+
self._fit_to_known(bootstrap=bootstrap, **fit_kwargs)
353+
else:
354+
self._fit_on_new(X, y, bootstrap=bootstrap, **fit_kwargs)
347355

348356

349357
class BaseCommittee(ABC, BaseEstimator):

tests/core_tests.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,22 @@ def test_score(self):
380380
score_return
381381
)
382382

383+
def test_teach(self):
384+
X_training = np.random.rand(10, 2)
385+
y_training = np.random.randint(0, 2, size=10)
386+
387+
for n_samples in range(1, 10):
388+
X = np.random.rand(n_samples, 2)
389+
y = np.random.randint(0, 2, size=n_samples)
390+
391+
learner = modAL.models.ActiveLearner(
392+
X_training=X_training, y_training=y_training,
393+
estimator=mock.MockClassifier()
394+
)
395+
396+
learner.teach(X, y, only_new=False)
397+
learner.teach(X, y, only_new=True)
398+
383399
def test_keras(self):
384400
pass
385401

0 commit comments

Comments
 (0)