Skip to content

Commit 7cb213b

Browse files
committed
add: tests for multilabel strategies added
1 parent af6bb32 commit 7cb213b

File tree

1 file changed

+7
-15
lines changed

1 file changed

+7
-15
lines changed

tests/core_tests.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -942,29 +942,21 @@ def test_SVM_loss(self):
942942
self.assertEqual(avg_loss.shape, (len(X_pool), ))
943943
self.assertEqual(mcc_loss.shape, (len(X_pool),))
944944

945-
def test_mean_max_loss(self):
945+
def test_strategies(self):
946946
for n_classes in range(2, 10):
947947
for n_pool_instances in range(1, 10):
948948
for n_query_instances in range(1, min(n_pool_instances, 3)):
949949
X_training = np.random.rand(n_pool_instances, 5)
950950
y_training = np.random.randint(0, 2, size=(n_pool_instances, n_classes))
951951
X_pool = np.random.rand(n_pool_instances, 5)
952-
y_pool = np.random.randint(0, 2, size=(n_pool_instances, n_classes))
953-
classifier = OneVsRestClassifier(SVC())
954-
classifier.fit(X_training, y_training)
955-
query_idx, query_inst = modAL.multilabel.mean_max_loss(classifier, X_pool, n_query_instances)
956-
957-
def test_max_loss(self):
958-
for n_classes in range(2, 10):
959-
for n_pool_instances in range(1, 10):
960-
for n_query_instances in range(1, min(n_pool_instances, 3)):
961-
X_training = np.random.rand(n_pool_instances, 5)
962-
y_training = np.random.randint(0, 2, size=(n_pool_instances, n_classes))
963-
X_pool = np.random.rand(n_pool_instances, 5)
964-
y_pool = np.random.randint(0, 2, size=(n_pool_instances, n_classes))
965952
classifier = OneVsRestClassifier(SVC(probability=True))
966953
classifier.fit(X_training, y_training)
967-
query_idx, query_inst = modAL.multilabel.max_loss(classifier, X_pool, n_query_instances)
954+
modAL.multilabel.mean_max_loss(classifier, X_pool, n_query_instances)
955+
modAL.multilabel.max_loss(classifier, X_pool, n_query_instances)
956+
modAL.multilabel.min_confidence(classifier, X_pool, n_query_instances)
957+
modAL.multilabel.avg_confidence(classifier, X_pool, n_query_instances)
958+
modAL.multilabel.max_score(classifier, X_pool, n_query_instances)
959+
modAL.multilabel.avg_score(classifier, X_pool, n_query_instances)
968960

969961

970962
class TestExamples(unittest.TestCase):

0 commit comments

Comments
 (0)