Skip to content

Commit 171e2e9

Browse files
author
Boyan Hristov
committed
#20, #104 - added tests for on_transformed functionality and pandas support; small fixes
1 parent 1ad79fe commit 171e2e9

File tree

3 files changed

+188
-7
lines changed

3 files changed

+188
-7
lines changed

modAL/models/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,8 @@ def __init__(self, learner_list: List[BaseLearner], query_strategy: Callable, on
300300
self.learner_list = learner_list
301301
self.query_strategy = query_strategy
302302
self.on_transformed = on_transformed
303+
# TODO: update training data when using fit() and teach() methods
304+
self.X_training = None
303305

304306
def __iter__(self) -> Iterator[BaseLearner]:
305307
for learner in self.learner_list:

modAL/models/learners.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from modAL.models.base import BaseLearner, BaseCommittee
99
from modAL.utils.validation import check_class_labels, check_class_proba
10-
from modAL.utils.data import modALinput
10+
from modAL.utils.data import modALinput, retrieve_rows
1111
from modAL.uncertainty import uncertainty_sampling
1212
from modAL.disagreement import vote_entropy_sampling, max_std_sampling
1313
from modAL.acquisition import max_EI
@@ -187,7 +187,7 @@ def __init__(self,
187187
# setting the maximum value
188188
if self.y_training is not None:
189189
max_idx = np.argmax(self.y_training)
190-
self.X_max = self.X_training[max_idx]
190+
self.X_max = retrieve_rows(self.X_training, max_idx)
191191
self.y_max = self.y_training[max_idx]
192192
else:
193193
self.X_max = None
@@ -198,7 +198,7 @@ def _set_max(self, X: modALinput, y: modALinput) -> None:
198198
y_max = y[max_idx]
199199
if y_max > self.y_max:
200200
self.y_max = y_max
201-
self.X_max = X[max_idx]
201+
self.X_max = retrieve_rows(X, max_idx)
202202

203203
def get_max(self) -> Tuple:
204204
"""
@@ -248,6 +248,8 @@ class Committee(BaseCommittee):
248248
learner_list: A list of ActiveLearners forming the Committee.
249249
query_strategy: Query strategy function. Committee supports disagreement-based query strategies from
250250
:mod:`modAL.disagreement`, but uncertainty-based ones from :mod:`modAL.uncertainty` are also supported.
251+
on_transformed: Whether to transform samples with the pipeline defined by each learner's estimator
252+
when applying the query strategy.
251253
252254
Attributes:
253255
classes_: Class labels known by the Committee.
@@ -288,8 +290,9 @@ class Committee(BaseCommittee):
288290
... y=iris['target'][query_idx].reshape(1, )
289291
... )
290292
"""
291-
def __init__(self, learner_list: List[ActiveLearner], query_strategy: Callable = vote_entropy_sampling) -> None:
292-
super().__init__(learner_list, query_strategy)
293+
def __init__(self, learner_list: List[ActiveLearner], query_strategy: Callable = vote_entropy_sampling,
294+
on_transformed: bool = False) -> None:
295+
super().__init__(learner_list, query_strategy, on_transformed)
293296
self._set_classes()
294297

295298
def _set_classes(self):
@@ -456,6 +459,8 @@ class CommitteeRegressor(BaseCommittee):
456459
Args:
457460
learner_list: A list of ActiveLearners forming the CommitteeRegressor.
458461
query_strategy: Query strategy function.
462+
on_transformed: Whether to transform samples with the pipeline defined by each learner's estimator
463+
when applying the query strategy.
459464
460465
Examples:
461466
@@ -499,8 +504,9 @@ class CommitteeRegressor(BaseCommittee):
499504
... query_idx, query_instance = committee.query(X.reshape(-1, 1))
500505
... committee.teach(X[query_idx].reshape(-1, 1), y[query_idx].reshape(-1, 1))
501506
"""
502-
def __init__(self, learner_list: List[ActiveLearner], query_strategy: Callable = max_std_sampling) -> None:
503-
super().__init__(learner_list, query_strategy)
507+
def __init__(self, learner_list: List[ActiveLearner], query_strategy: Callable = max_std_sampling,
508+
on_transformed: bool = False) -> None:
509+
super().__init__(learner_list, query_strategy, on_transformed)
504510

505511
def predict(self, X: modALinput, return_std: bool = False, **predict_kwargs) -> Any:
506512
"""

tests/core_tests.py

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import random
22
import unittest
33
import numpy as np
4+
import pandas as pd
45

56
import mock
67
import modAL.models.base
@@ -26,6 +27,8 @@
2627
from sklearn.metrics import confusion_matrix
2728
from sklearn.svm import SVC
2829
from sklearn.multiclass import OneVsRestClassifier
30+
from sklearn.pipeline import make_pipeline
31+
from sklearn.preprocessing import FunctionTransformer
2932
from scipy.stats import entropy, norm
3033
from scipy.special import ndtr
3134
from scipy import sparse as sp
@@ -788,6 +791,68 @@ def test_sparse_matrices(self):
788791
query_idx, query_inst = learner.query(X_pool)
789792
learner.teach(X_pool[query_idx], y_pool[query_idx])
790793

794+
def test_on_transformed(self):
795+
n_samples = 10
796+
n_features = 5
797+
query_strategies = [
798+
modAL.batch.uncertainty_batch_sampling
799+
# add further strategies which work with instance representations
800+
# no further ones as of 25.09.2020
801+
]
802+
X_pool = np.random.rand(n_samples, n_features)
803+
804+
# use pandas data frame as X_pool, which will be transformed back to numpy with sklearn pipeline
805+
X_pool = pd.DataFrame(X_pool)
806+
807+
y_pool = np.random.randint(0, 2, size=(n_samples,))
808+
train_idx = np.random.choice(range(n_samples), size=2, replace=False)
809+
810+
for query_strategy in query_strategies:
811+
learner = modAL.models.learners.ActiveLearner(
812+
estimator=make_pipeline(
813+
FunctionTransformer(func=pd.DataFrame.to_numpy),
814+
RandomForestClassifier(n_estimators=10)
815+
),
816+
query_strategy=query_strategy,
817+
X_training=X_pool.iloc[train_idx],
818+
y_training=y_pool[train_idx],
819+
on_transformed=True
820+
)
821+
query_idx, query_inst = learner.query(X_pool)
822+
learner.teach(X_pool.iloc[query_idx], y_pool[query_idx])
823+
824+
def test_old_query_strategy_interface(self):
825+
n_samples = 10
826+
n_features = 5
827+
X_pool = np.random.rand(n_samples, n_features)
828+
y_pool = np.random.randint(0, 2, size=(n_samples,))
829+
830+
# defining a custom query strategy also returning the selected instance
831+
# make sure even if a query strategy works in some funny way
832+
# (e.g. instance not matching instance index),
833+
# the old interface remains unchanged
834+
query_idx_ = np.random.choice(n_samples, 2)
835+
query_instance_ = X_pool[(query_idx_ + 1) % len(X_pool)]
836+
837+
def custom_query_strategy(classifier, X):
838+
return query_idx_, query_instance_
839+
840+
841+
train_idx = np.random.choice(range(n_samples), size=2, replace=False)
842+
custom_query_learner = modAL.models.learners.ActiveLearner(
843+
estimator=RandomForestClassifier(n_estimators=10),
844+
query_strategy=custom_query_strategy,
845+
X_training=X_pool[train_idx], y_training=y_pool[train_idx]
846+
)
847+
848+
query_idx, query_instance = custom_query_learner.query(X_pool)
849+
custom_query_learner.teach(
850+
X=X_pool[query_idx],
851+
y=y_pool[query_idx]
852+
)
853+
np.testing.assert_equal(query_idx, query_idx_)
854+
np.testing.assert_equal(query_instance, query_instance_)
855+
791856

792857
class TestBayesianOptimizer(unittest.TestCase):
793858
def test_set_max(self):
@@ -897,6 +962,39 @@ def test_teach(self):
897962
)
898963
learner.teach(X, y, bootstrap=bootstrap, only_new=only_new)
899964

965+
def test_on_transformed(self):
966+
n_samples = 10
967+
n_features = 5
968+
query_strategies = [
969+
# TODO remove, added just to make sure on_transformed doesn't break anything
970+
# but it has no influence on this strategy, nothing special tested here
971+
mock.MockFunction(return_val=[np.random.randint(0, n_samples)])
972+
973+
# add further strategies which work with instance representations
974+
# no further ones as of 25.09.2020
975+
]
976+
X_pool = np.random.rand(n_samples, n_features)
977+
978+
# use pandas data frame as X_pool, which will be transformed back to numpy with sklearn pipeline
979+
X_pool = pd.DataFrame(X_pool)
980+
981+
y_pool = np.random.rand(n_samples)
982+
train_idx = np.random.choice(range(n_samples), size=2, replace=False)
983+
984+
for query_strategy in query_strategies:
985+
learner = modAL.models.learners.BayesianOptimizer(
986+
estimator=make_pipeline(
987+
FunctionTransformer(func=pd.DataFrame.to_numpy),
988+
GaussianProcessRegressor()
989+
),
990+
query_strategy=query_strategy,
991+
X_training=X_pool.iloc[train_idx],
992+
y_training=y_pool[train_idx],
993+
on_transformed=True
994+
)
995+
query_idx, query_inst = learner.query(X_pool)
996+
learner.teach(X_pool.iloc[query_idx], y_pool[query_idx])
997+
900998

901999
class TestCommittee(unittest.TestCase):
9021000

@@ -1007,6 +1105,42 @@ def test_teach(self):
10071105

10081106
committee.teach(X, y, bootstrap=bootstrap, only_new=only_new)
10091107

1108+
def test_on_transformed(self):
1109+
n_samples = 10
1110+
n_features = 5
1111+
query_strategies = [
1112+
modAL.batch.uncertainty_batch_sampling
1113+
# add further strategies which work with instance representations
1114+
# no further ones as of 25.09.2020
1115+
]
1116+
X_pool = np.random.rand(n_samples, n_features)
1117+
1118+
# use pandas data frame as X_pool, which will be transformed back to numpy with sklearn pipeline
1119+
X_pool = pd.DataFrame(X_pool)
1120+
1121+
y_pool = np.random.randint(0, 2, size=(n_samples,))
1122+
train_idx = np.random.choice(range(n_samples), size=5, replace=False)
1123+
1124+
learner_list = [modAL.models.learners.ActiveLearner(
1125+
estimator=make_pipeline(
1126+
FunctionTransformer(func=pd.DataFrame.to_numpy),
1127+
RandomForestClassifier(n_estimators=10)
1128+
),
1129+
# committee learners can contain different amounts of
1130+
# different instances
1131+
X_training=X_pool.iloc[train_idx[(np.arange(i + 1) + i) % len(train_idx)]],
1132+
y_training=y_pool[train_idx[(np.arange(i + 1) + i) % len(train_idx)]],
1133+
) for i in range(3)]
1134+
1135+
for query_strategy in query_strategies:
1136+
committee = modAL.models.learners.Committee(
1137+
learner_list=learner_list,
1138+
query_strategy=query_strategy,
1139+
on_transformed=True
1140+
)
1141+
query_idx, query_inst = committee.query(X_pool)
1142+
committee.teach(X_pool.iloc[query_idx], y_pool[query_idx])
1143+
10101144

10111145
class TestCommitteeRegressor(unittest.TestCase):
10121146

@@ -1040,6 +1174,45 @@ def test_vote(self):
10401174
vote_output
10411175
)
10421176

1177+
def test_on_transformed(self):
1178+
n_samples = 10
1179+
n_features = 5
1180+
query_strategies = [
1181+
# TODO remove, added just to make sure on_transformed doesn't break anything
1182+
# but it has no influence on this strategy, nothing special tested here
1183+
mock.MockFunction(return_val=[np.random.randint(0, n_samples)])
1184+
1185+
# add further strategies which work with instance representations
1186+
# no further ones as of 25.09.2020
1187+
]
1188+
X_pool = np.random.rand(n_samples, n_features)
1189+
1190+
# use pandas data frame as X_pool, which will be transformed back to numpy with sklearn pipeline
1191+
X_pool = pd.DataFrame(X_pool)
1192+
1193+
y_pool = np.random.rand(n_samples)
1194+
train_idx = np.random.choice(range(n_samples), size=2, replace=False)
1195+
1196+
learner_list = [modAL.models.learners.ActiveLearner(
1197+
estimator=make_pipeline(
1198+
FunctionTransformer(func=pd.DataFrame.to_numpy),
1199+
GaussianProcessRegressor()
1200+
),
1201+
# committee learners can contain different amounts of
1202+
# different instances
1203+
X_training=X_pool.iloc[train_idx[(np.arange(i + 1) + i) % len(train_idx)]],
1204+
y_training=y_pool[train_idx[(np.arange(i + 1) + i) % len(train_idx)]],
1205+
) for i in range(3)]
1206+
1207+
for query_strategy in query_strategies:
1208+
committee = modAL.models.learners.CommitteeRegressor(
1209+
learner_list=learner_list,
1210+
query_strategy=query_strategy,
1211+
on_transformed=True
1212+
)
1213+
query_idx, query_inst = committee.query(X_pool)
1214+
committee.teach(X_pool.iloc[query_idx], y_pool[query_idx])
1215+
10431216

10441217
class TestMultilabel(unittest.TestCase):
10451218
def test_SVM_loss(self):

0 commit comments

Comments
 (0)