@@ -174,22 +174,22 @@ def test_check_and_transform_label_format(self):
174174
175175 # test input shape (nb_samples,)
176176 labels = np .array ([3 , 1 , 4 ])
177- labels_transformed = check_and_transform_label_format (labels )
177+ labels_transformed = check_and_transform_label_format (labels , 5 )
178178 np .testing .assert_array_equal (labels_transformed , labels_expected )
179179
180180 # test input shape (nb_samples, 1)
181181 labels = np .array ([[3 ], [1 ], [4 ]])
182- labels_transformed = check_and_transform_label_format (labels )
182+ labels_transformed = check_and_transform_label_format (labels , 5 )
183183 np .testing .assert_array_equal (labels_transformed , labels_expected )
184184
185185 # test input shape (nb_samples, nb_classes)
186186 labels = np .array ([[0 , 0 , 0 , 1 , 0 ], [0 , 1 , 0 , 0 , 0 ], [0 , 0 , 0 , 0 , 1 ]])
187- labels_transformed = check_and_transform_label_format (labels )
187+ labels_transformed = check_and_transform_label_format (labels , 5 )
188188 np .testing .assert_array_equal (labels_transformed , labels_expected )
189189
190190 # test input shape (nb_samples, nb_classes) with return_one_hot=False
191191 labels = np .array ([[0 , 0 , 0 , 1 , 0 ], [0 , 1 , 0 , 0 , 0 ], [0 , 0 , 0 , 0 , 1 ]])
192- labels_transformed = check_and_transform_label_format (labels , return_one_hot = False )
192+ labels_transformed = check_and_transform_label_format (labels , 5 , return_one_hot = False )
193193 np .testing .assert_array_equal (labels_transformed , np .argmax (labels_expected , axis = 1 ))
194194
195195 # ValueError for len(labels.shape) > 2
0 commit comments