Skip to content

Commit b970b30

Browse files
author
Beat Buesser
committed
Merge branch 'abigailgold-dev_1.11.0_label_fixes' into development_maintenance_111
2 parents 24aad8d + a4ae43f commit b970b30

File tree

5 files changed

+66
-20
lines changed

5 files changed

+66
-20
lines changed

art/attacks/inference/attribute_inference/black_box.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def fit(self, x: np.ndarray, y: Optional[np.ndarray] = None) -> None:
155155
if ClassifierMixin in type(self.estimator).__mro__:
156156
predictions = np.array([np.argmax(arr) for arr in self.estimator.predict(x)]).reshape(-1, 1)
157157
if y is not None:
158-
y = check_and_transform_label_format(y, nb_classes=len(np.unique(y)), return_one_hot=True)
158+
y = check_and_transform_label_format(y, nb_classes=self.estimator.nb_classes, return_one_hot=True)
159159
else: # Regression model
160160
if self.scale_range is not None:
161161
predictions = minmax_scale(self.estimator.predict(x).reshape(-1, 1), feature_range=self.scale_range)
@@ -237,7 +237,7 @@ def infer(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.n
237237
else:
238238
x_test = np.concatenate((x, pred), axis=1).astype(np.float32)
239239
if y is not None:
240-
y = check_and_transform_label_format(y, nb_classes=len(np.unique(y)), return_one_hot=True)
240+
y = check_and_transform_label_format(y, nb_classes=self.estimator.nb_classes, return_one_hot=True)
241241

242242
if y is not None:
243243
x_test = np.concatenate((x_test, y), axis=1)

art/attacks/inference/attribute_inference/true_label_baseline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def fit(self, x: np.ndarray, y: np.ndarray) -> None:
158158
normalized_labels = y * self.prediction_normal_factor
159159
normalized_labels = normalized_labels.reshape(-1, 1)
160160
else:
161-
normalized_labels = check_and_transform_label_format(y, nb_classes=len(np.unique(y)), return_one_hot=True)
161+
normalized_labels = check_and_transform_label_format(y, return_one_hot=True)
162162
x_train = np.concatenate((np.delete(x, self.attack_feature, 1), normalized_labels), axis=1).astype(np.float32)
163163

164164
# train attack model
@@ -194,7 +194,7 @@ def infer(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.n
194194
normalized_labels = y * self.prediction_normal_factor
195195
normalized_labels = normalized_labels.reshape(-1, 1)
196196
else:
197-
normalized_labels = check_and_transform_label_format(y, nb_classes=len(np.unique(y)), return_one_hot=True)
197+
normalized_labels = check_and_transform_label_format(y, return_one_hot=True)
198198
x_test = np.concatenate((x, normalized_labels), axis=1).astype(np.float32)
199199

200200
predictions = self.attack_model.predict(x_test).astype(np.float32)

art/attacks/inference/membership_inference/black_box.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,8 @@ def fit( # pylint: disable=W0613
187187
raise ValueError("Shape of test_x does not match input_shape of estimator")
188188

189189
if not self._regressor_model:
190-
y = check_and_transform_label_format(y, len(np.unique(y)), return_one_hot=True)
191-
test_y = check_and_transform_label_format(test_y, len(np.unique(test_y)), return_one_hot=True)
190+
y = check_and_transform_label_format(y, self.estimator.nb_classes, return_one_hot=True) # type: ignore
191+
test_y = check_and_transform_label_format(test_y, self.estimator.nb_classes, return_one_hot=True)
192192

193193
if y.shape[0] != x.shape[0]: # pragma: no cover
194194
raise ValueError("Number of rows in x and y do not match")
@@ -258,7 +258,7 @@ def fit( # pylint: disable=W0613
258258
loss.backward()
259259
optimizer.step()
260260
else:
261-
y_ready = check_and_transform_label_format(y_new, len(np.unique(y_new)), return_one_hot=False)
261+
y_ready = check_and_transform_label_format(y_new, nb_classes=2, return_one_hot=False)
262262
self.attack_model.fit(np.c_[x_1, x_2], y_ready.ravel()) # type: ignore
263263

264264
def infer(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.ndarray:
@@ -285,7 +285,7 @@ def infer(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.n
285285
probabilities = False
286286

287287
if not self._regressor_model:
288-
y = check_and_transform_label_format(y, len(np.unique(y)), return_one_hot=True)
288+
y = check_and_transform_label_format(y, self.estimator.nb_classes, return_one_hot=True)
289289

290290
if y is None:
291291
raise ValueError("None value detected.")

art/utils.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -686,26 +686,33 @@ def check_and_transform_label_format(
686686
if not return_one_hot:
687687
labels_return = np.argmax(labels, axis=1)
688688
labels_return = np.expand_dims(labels_return, axis=1)
689-
elif (
690-
len(labels.shape) == 2 and labels.shape[1] == 1 and nb_classes is not None and nb_classes > 2
691-
): # multi-class, index labels
692-
if return_one_hot:
693-
labels_return = to_categorical(labels, nb_classes)
689+
elif len(labels.shape) == 2 and labels.shape[1] == 1:
690+
if nb_classes is None:
691+
nb_classes = np.max(labels) + 1
692+
if nb_classes > 2: # multi-class, index labels
693+
if return_one_hot:
694+
labels_return = to_categorical(labels, nb_classes)
695+
else:
696+
labels_return = np.expand_dims(labels, axis=1)
697+
elif nb_classes == 2: # binary, index labels
698+
if return_one_hot:
699+
labels_return = to_categorical(labels, nb_classes)
694700
else:
695-
labels_return = np.expand_dims(labels, axis=1)
696-
elif (
697-
len(labels.shape) == 2 and labels.shape[1] == 1 and nb_classes is not None and nb_classes == 2
698-
): # binary, index labels
699-
if return_one_hot:
700-
labels_return = to_categorical(labels, nb_classes)
701+
raise ValueError(
702+
"Shape of labels not recognised."
703+
"Please provide labels in shape (nb_samples,) or (nb_samples, "
704+
"nb_classes)"
705+
)
701706
elif len(labels.shape) == 1: # index labels
702707
if return_one_hot:
703708
labels_return = to_categorical(labels, nb_classes)
704709
else:
705710
labels_return = np.expand_dims(labels, axis=1)
706711
else:
707712
raise ValueError(
708-
"Shape of labels not recognised." "Please provide labels in shape (nb_samples,) or (nb_samples, nb_classes)"
713+
"Shape of labels not recognised."
714+
"Please provide labels in shape (nb_samples,) or (nb_samples, "
715+
"nb_classes)"
709716
)
710717

711718
return labels_return

tests/test_utils.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,45 @@ def test_check_and_transform_label_format(self):
210210
labels_transformed, np.expand_dims(np.argmax(labels_expected_binary, axis=1), axis=1)
211211
)
212212

213+
# with no nb_classes
214+
215+
# test input shape (nb_samples,)
216+
labels = np.array([3, 1, 4])
217+
labels_transformed = check_and_transform_label_format(labels, return_one_hot=True)
218+
np.testing.assert_array_equal(labels_transformed, labels_expected)
219+
220+
# test input shape (nb_samples, 1)
221+
labels = np.array([[3], [1], [4]])
222+
labels_transformed = check_and_transform_label_format(labels, return_one_hot=True)
223+
np.testing.assert_array_equal(labels_transformed, labels_expected)
224+
225+
# test input shape (nb_samples, 1) - binary
226+
labels = np.array([[1], [0], [1]])
227+
labels_transformed = check_and_transform_label_format(labels, return_one_hot=True)
228+
np.testing.assert_array_equal(labels_transformed, labels_expected_binary)
229+
230+
# test input shape (nb_samples, 1) - binary
231+
labels = np.array([[0, 1], [1, 0], [0, 1]])
232+
labels_transformed = check_and_transform_label_format(labels, return_one_hot=True)
233+
np.testing.assert_array_equal(labels_transformed, labels_expected_binary)
234+
235+
# test input shape (nb_samples, nb_classes)
236+
labels = np.array([[0, 0, 0, 1, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 1]])
237+
labels_transformed = check_and_transform_label_format(labels, return_one_hot=True)
238+
np.testing.assert_array_equal(labels_transformed, labels_expected)
239+
240+
# test input shape (nb_samples, nb_classes) with return_one_hot=False
241+
labels = np.array([[0, 0, 0, 1, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 1]])
242+
labels_transformed = check_and_transform_label_format(labels, return_one_hot=False)
243+
np.testing.assert_array_equal(labels_transformed, np.expand_dims(np.argmax(labels_expected, axis=1), axis=1))
244+
245+
# test input shape (nb_samples, 1) - binary
246+
labels = np.array([[1], [0], [1]])
247+
labels_transformed = check_and_transform_label_format(labels, return_one_hot=False)
248+
np.testing.assert_array_equal(
249+
labels_transformed, np.expand_dims(np.argmax(labels_expected_binary, axis=1), axis=1)
250+
)
251+
213252
# ValueError for len(labels.shape) > 2
214253
labels = np.array([[[0, 0, 0, 1, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 1]]])
215254
with self.assertRaises(ValueError):

0 commit comments

Comments
 (0)