@@ -171,26 +171,44 @@ def test_to_categorical(self):
171171
172172 def test_check_and_transform_label_format (self ):
173173 labels_expected = np .array ([[0 , 0 , 0 , 1 , 0 ], [0 , 1 , 0 , 0 , 0 ], [0 , 0 , 0 , 0 , 1 ]])
174+ labels_expected_binary = np .array ([[0 , 1 ], [1 , 0 ], [0 , 1 ]])
174175
175176 # test input shape (nb_samples,)
176177 labels = np .array ([3 , 1 , 4 ])
177- labels_transformed = check_and_transform_label_format (labels , 5 )
178+ labels_transformed = check_and_transform_label_format (labels , nb_classes = 5 , return_one_hot = True )
178179 np .testing .assert_array_equal (labels_transformed , labels_expected )
179180
180181 # test input shape (nb_samples, 1)
181182 labels = np .array ([[3 ], [1 ], [4 ]])
182- labels_transformed = check_and_transform_label_format (labels , 5 )
183+ labels_transformed = check_and_transform_label_format (labels , nb_classes = 5 , return_one_hot = True )
183184 np .testing .assert_array_equal (labels_transformed , labels_expected )
184185
186+ # test input shape (nb_samples, 1) - binary
187+ labels = np .array ([[1 ], [0 ], [1 ]])
188+ labels_transformed = check_and_transform_label_format (labels , nb_classes = 2 , return_one_hot = True )
189+ np .testing .assert_array_equal (labels_transformed , labels_expected_binary )
190+
191+ # test input shape (nb_samples, 1) - binary
192+ labels = np .array ([[0 , 1 ], [1 , 0 ], [0 , 1 ]])
193+ labels_transformed = check_and_transform_label_format (labels , nb_classes = 2 , return_one_hot = True )
194+ np .testing .assert_array_equal (labels_transformed , labels_expected_binary )
195+
185196 # test input shape (nb_samples, nb_classes)
186197 labels = np .array ([[0 , 0 , 0 , 1 , 0 ], [0 , 1 , 0 , 0 , 0 ], [0 , 0 , 0 , 0 , 1 ]])
187- labels_transformed = check_and_transform_label_format (labels , 5 )
198+ labels_transformed = check_and_transform_label_format (labels , nb_classes = 5 , return_one_hot = True )
188199 np .testing .assert_array_equal (labels_transformed , labels_expected )
189200
190201 # test input shape (nb_samples, nb_classes) with return_one_hot=False
191202 labels = np .array ([[0 , 0 , 0 , 1 , 0 ], [0 , 1 , 0 , 0 , 0 ], [0 , 0 , 0 , 0 , 1 ]])
192- labels_transformed = check_and_transform_label_format (labels , 5 , return_one_hot = False )
193- np .testing .assert_array_equal (labels_transformed , np .argmax (labels_expected , axis = 1 ))
203+ labels_transformed = check_and_transform_label_format (labels , nb_classes = 5 , return_one_hot = False )
204+ np .testing .assert_array_equal (labels_transformed , np .expand_dims (np .argmax (labels_expected , axis = 1 ), axis = 1 ))
205+
206+ # test input shape (nb_samples, 1) - binary
207+ labels = np .array ([[1 ], [0 ], [1 ]])
208+ labels_transformed = check_and_transform_label_format (labels , nb_classes = 2 , return_one_hot = False )
209+ np .testing .assert_array_equal (
210+ labels_transformed , np .expand_dims (np .argmax (labels_expected_binary , axis = 1 ), axis = 1 )
211+ )
194212
195213 # ValueError for len(labels.shape) > 2
196214 labels = np .array ([[[0 , 0 , 0 , 1 , 0 ], [0 , 1 , 0 , 0 , 0 ], [0 , 0 , 0 , 0 , 1 ]]])
0 commit comments