Skip to content

Commit 3bf9b55

Browse files
author
TrojAISec
committed
add nb_classes to test_check_and_transform_label_format
1 parent a480e17 commit 3bf9b55

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

tests/test_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -174,22 +174,22 @@ def test_check_and_transform_label_format(self):
174174

175175
# test input shape (nb_samples,)
176176
labels = np.array([3, 1, 4])
177-
labels_transformed = check_and_transform_label_format(labels)
177+
labels_transformed = check_and_transform_label_format(labels, 5)
178178
np.testing.assert_array_equal(labels_transformed, labels_expected)
179179

180180
# test input shape (nb_samples, 1)
181181
labels = np.array([[3], [1], [4]])
182-
labels_transformed = check_and_transform_label_format(labels)
182+
labels_transformed = check_and_transform_label_format(labels, 5)
183183
np.testing.assert_array_equal(labels_transformed, labels_expected)
184184

185185
# test input shape (nb_samples, nb_classes)
186186
labels = np.array([[0, 0, 0, 1, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 1]])
187-
labels_transformed = check_and_transform_label_format(labels)
187+
labels_transformed = check_and_transform_label_format(labels, 5)
188188
np.testing.assert_array_equal(labels_transformed, labels_expected)
189189

190190
# test input shape (nb_samples, nb_classes) with return_one_hot=False
191191
labels = np.array([[0, 0, 0, 1, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 1]])
192-
labels_transformed = check_and_transform_label_format(labels, return_one_hot=False)
192+
labels_transformed = check_and_transform_label_format(labels, 5, return_one_hot=False)
193193
np.testing.assert_array_equal(labels_transformed, np.argmax(labels_expected, axis=1))
194194

195195
# ValueError for len(labels.shape) > 2

0 commit comments

Comments
 (0)