Skip to content

Commit a7c1f7b

Browse files
authored
Merge pull request #1443 from Trusted-AI/development_issue_1334
Update check_and_transform_label_format for index labels
2 parents df2e613 + e594795 commit a7c1f7b

File tree

4 files changed

+44
-17
lines changed

4 files changed

+44
-17
lines changed

art/attacks/evasion/pixel_threshold.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,9 +156,11 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> n
156156
raise ValueError(
157157
"This attack has not yet been tested for binary classification with a single output classifier."
158158
)
159-
if len(y.shape) > 1:
159+
if y.ndim > 1 and y.shape[1] > 1:
160160
y = np.argmax(y, axis=1)
161161

162+
y = np.squeeze(y)
163+
162164
if self.th is None:
163165
logger.info(
164166
"Performing minimal perturbation Attack. \

art/estimators/classification/keras.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -562,13 +562,14 @@ def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: in
562562
`fit_generator` function in Keras and will be passed to this function as such. Including the number of
563563
epochs or the number of steps per epoch as part of this argument will result in as error.
564564
"""
565+
y_ndim = y.ndim
565566
y = check_and_transform_label_format(y, self.nb_classes)
566567

567568
# Apply preprocessing
568569
x_preprocessed, y_preprocessed = self._apply_preprocessing(x, y, fit=True)
569570

570571
# Adjust the shape of y for loss functions that do not take labels in one-hot encoding
571-
if self._reduce_labels:
572+
if self._reduce_labels or y_ndim == 1:
572573
y_preprocessed = np.argmax(y_preprocessed, axis=1)
573574

574575
self._model.fit(x=x_preprocessed, y=y_preprocessed, batch_size=batch_size, epochs=nb_epochs, **kwargs)

art/utils.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -535,21 +535,27 @@ def check_and_transform_label_format(
535535
:return: Labels with shape `(nb_samples, nb_classes)` (one-hot) or `(nb_samples,)` (index).
536536
"""
537537
if labels is not None:
538-
if len(labels.shape) == 2 and labels.shape[1] > 1:
538+
if len(labels.shape) == 2 and labels.shape[1] > 1: # multi-class, one-hot encoded
539539
if not return_one_hot:
540540
labels = np.argmax(labels, axis=1)
541-
elif len(labels.shape) == 2 and labels.shape[1] == 1 and nb_classes is not None and nb_classes > 2:
542-
labels = np.squeeze(labels)
541+
labels = np.expand_dims(labels, axis=1)
542+
elif (
543+
len(labels.shape) == 2 and labels.shape[1] == 1 and nb_classes is not None and nb_classes > 2
544+
): # multi-class, index labels
543545
if return_one_hot:
544546
labels = to_categorical(labels, nb_classes)
545-
elif len(labels.shape) == 2 and labels.shape[1] == 1 and nb_classes is not None and nb_classes == 2:
546-
pass
547-
elif len(labels.shape) == 1:
547+
else:
548+
labels = np.expand_dims(labels, axis=1)
549+
elif (
550+
len(labels.shape) == 2 and labels.shape[1] == 1 and nb_classes is not None and nb_classes == 2
551+
): # binary, index labels
548552
if return_one_hot:
549-
if nb_classes == 2:
550-
labels = np.expand_dims(labels, axis=1)
551-
else:
552-
labels = to_categorical(labels, nb_classes)
553+
labels = to_categorical(labels, nb_classes)
554+
elif len(labels.shape) == 1: # index labels
555+
if return_one_hot:
556+
labels = to_categorical(labels, nb_classes)
557+
else:
558+
labels = np.expand_dims(labels, axis=1)
553559
else:
554560
raise ValueError(
555561
"Shape of labels not recognised."

tests/test_utils.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -171,26 +171,44 @@ def test_to_categorical(self):
171171

172172
def test_check_and_transform_label_format(self):
173173
labels_expected = np.array([[0, 0, 0, 1, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 1]])
174+
labels_expected_binary = np.array([[0, 1], [1, 0], [0, 1]])
174175

175176
# test input shape (nb_samples,)
176177
labels = np.array([3, 1, 4])
177-
labels_transformed = check_and_transform_label_format(labels, 5)
178+
labels_transformed = check_and_transform_label_format(labels, nb_classes=5, return_one_hot=True)
178179
np.testing.assert_array_equal(labels_transformed, labels_expected)
179180

180181
# test input shape (nb_samples, 1)
181182
labels = np.array([[3], [1], [4]])
182-
labels_transformed = check_and_transform_label_format(labels, 5)
183+
labels_transformed = check_and_transform_label_format(labels, nb_classes=5, return_one_hot=True)
183184
np.testing.assert_array_equal(labels_transformed, labels_expected)
184185

186+
# test input shape (nb_samples, 1) - binary
187+
labels = np.array([[1], [0], [1]])
188+
labels_transformed = check_and_transform_label_format(labels, nb_classes=2, return_one_hot=True)
189+
np.testing.assert_array_equal(labels_transformed, labels_expected_binary)
190+
191+
# test input shape (nb_samples, 1) - binary
192+
labels = np.array([[0, 1], [1, 0], [0, 1]])
193+
labels_transformed = check_and_transform_label_format(labels, nb_classes=2, return_one_hot=True)
194+
np.testing.assert_array_equal(labels_transformed, labels_expected_binary)
195+
185196
# test input shape (nb_samples, nb_classes)
186197
labels = np.array([[0, 0, 0, 1, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 1]])
187-
labels_transformed = check_and_transform_label_format(labels, 5)
198+
labels_transformed = check_and_transform_label_format(labels, nb_classes=5, return_one_hot=True)
188199
np.testing.assert_array_equal(labels_transformed, labels_expected)
189200

190201
# test input shape (nb_samples, nb_classes) with return_one_hot=False
191202
labels = np.array([[0, 0, 0, 1, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 1]])
192-
labels_transformed = check_and_transform_label_format(labels, 5, return_one_hot=False)
193-
np.testing.assert_array_equal(labels_transformed, np.argmax(labels_expected, axis=1))
203+
labels_transformed = check_and_transform_label_format(labels, nb_classes=5, return_one_hot=False)
204+
np.testing.assert_array_equal(labels_transformed, np.expand_dims(np.argmax(labels_expected, axis=1), axis=1))
205+
206+
# test input shape (nb_samples, 1) - binary
207+
labels = np.array([[1], [0], [1]])
208+
labels_transformed = check_and_transform_label_format(labels, nb_classes=2, return_one_hot=False)
209+
np.testing.assert_array_equal(
210+
labels_transformed, np.expand_dims(np.argmax(labels_expected_binary, axis=1), axis=1)
211+
)
194212

195213
# ValueError for len(labels.shape) > 2
196214
labels = np.array([[[0, 0, 0, 1, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 1]]])

0 commit comments

Comments
 (0)