@@ -159,6 +159,40 @@ def test_data_vstack(self):
159159 # not supported formats
160160 self .assertRaises (TypeError , modAL .utils .data .data_vstack , (1 , 1 ))
161161
162+ # functions from modAL.utils.selection
163+
164+ def test_multi_argmax (self ):
165+ for n_pool in range (2 , 100 ):
166+ for n_instances in range (1 , n_pool ):
167+ utility = np .zeros (n_pool )
168+ max_idx = np .random .choice (range (n_pool ), size = n_instances , replace = False )
169+ utility [max_idx ] = 1e-10 + np .random .rand (n_instances , )
170+ np .testing .assert_equal (
171+ np .sort (modAL .utils .selection .multi_argmax (utility , n_instances )),
172+ np .sort (max_idx )
173+ )
174+
175+ def test_shuffled_argmax (self ):
176+ for n_pool in range (1 , 100 ):
177+ for n_instances in range (1 , n_pool + 1 ):
178+ values = np .random .permutation (n_pool )
179+ true_query_idx = np .argsort (values )[:n_instances ]
180+
181+ np .testing .assert_equal (
182+ true_query_idx ,
183+ modAL .utils .selection .shuffled_argmax (values , n_instances )
184+ )
185+
186+ def test_weighted_random (self ):
187+ for n_pool in range (2 , 100 ):
188+ for n_instances in range (1 , n_pool ):
189+ utility = np .ones (n_pool )
190+ query_idx = modAL .utils .selection .weighted_random (utility , n_instances )
191+ # testing for correct number of returned indices
192+ np .testing .assert_equal (len (query_idx ), n_instances )
193+ # testing for uniqueness of each query index
194+ np .testing .assert_equal (len (query_idx ), len (np .unique (query_idx )))
195+
162196
163197class TestAcquisitionFunctions (unittest .TestCase ):
164198 def test_acquisition_functions (self ):
@@ -524,30 +558,6 @@ def test_entropy_sampling(self):
524558 np .testing .assert_array_equal (query_idx , true_query_idx )
525559
526560
527- class TestQueries (unittest .TestCase ):
528-
529- def test_multi_argmax (self ):
530- for n_pool in range (2 , 100 ):
531- for n_instances in range (1 , n_pool ):
532- utility = np .zeros (n_pool )
533- max_idx = np .random .choice (range (n_pool ), size = n_instances , replace = False )
534- utility [max_idx ] = 1e-10 + np .random .rand (n_instances , )
535- np .testing .assert_equal (
536- np .sort (modAL .utils .selection .multi_argmax (utility , n_instances )),
537- np .sort (max_idx )
538- )
539-
540- def test_weighted_random (self ):
541- for n_pool in range (2 , 100 ):
542- for n_instances in range (1 , n_pool ):
543- utility = np .ones (n_pool )
544- query_idx = modAL .utils .selection .weighted_random (utility , n_instances )
545- # testing for correct number of returned indices
546- np .testing .assert_equal (len (query_idx ), n_instances )
547- # testing for uniqueness of each query index
548- np .testing .assert_equal (len (query_idx ), len (np .unique (query_idx )))
549-
550-
551561class TestActiveLearner (unittest .TestCase ):
552562
553563 def test_add_training_data (self ):
0 commit comments