Skip to content

Commit 4461f32

Browse files
Update the one-pixel shortcut attack and its unit tests based on review
Signed-off-by: Nicholas Audric Adriel <[email protected]>
1 parent 35c0caa commit 4461f32

File tree

2 files changed

+21
-10
lines changed

2 files changed

+21
-10
lines changed

art/attacks/poisoning/one_pixel_shortcut_attack.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
This module implements One Pixel Shortcut attacks on Deep Neural Networks.
2020
"""
2121

22-
from typing import Optional, Tuple
2322

2423
import numpy as np
2524

@@ -33,6 +32,8 @@ class OnePixelShortcutAttack(PoisoningAttackBlackBox):
3332
for each class by maximizing a mean-minus-variance objective over that class's
3433
images. The found pixel coordinate and color are applied to all images of the class
3534
(labels remain unchanged). Reference: Wu et al. (ICLR 2023).
35+
36+
| Paper link: https://arxiv.org/abs/2205.12141
3637
"""
3738

3839
attack_params: list = [] # No external parameters for this attack
@@ -45,7 +46,7 @@ def _check_params(self):
4546
# No parameters to validate
4647
pass
4748

48-
def poison(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
49+
def poison(self, x: np.ndarray, y: np.ndarray | None = None, **kwargs) -> tuple[np.ndarray, np.ndarray]:
4950
"""
5051
Generate an OPS-poisoned dataset from clean data.
5152

tests/attacks/poison/test_one_pixel_shortcut_attack.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,15 @@
2020

2121
import numpy as np
2222
import pytest
23+
from unittest.mock import patch
2324

2425
from art.attacks.poisoning.one_pixel_shortcut_attack import OnePixelShortcutAttack
25-
from unittest.mock import patch
2626
from tests.utils import ARTTestException
2727

2828
logger = logging.getLogger(__name__)
2929

3030

31+
@pytest.mark.framework_agnostic
3132
def test_one_pixel_per_image_and_label_preservation():
3233
try:
3334
x = np.zeros((4, 3, 3))
@@ -49,6 +50,7 @@ def test_one_pixel_per_image_and_label_preservation():
4950
raise ARTTestException("Pixel change or label consistency check failed") from e
5051

5152

53+
@pytest.mark.framework_agnostic
5254
def test_missing_labels_raises_error():
5355
try:
5456
x = np.zeros((3, 5, 5))
@@ -59,6 +61,7 @@ def test_missing_labels_raises_error():
5961
raise ARTTestException("Expected error not raised for missing labels") from e
6062

6163

64+
@pytest.mark.framework_agnostic
6265
def test_multi_channel_consistency():
6366
try:
6467
x = np.zeros((2, 2, 2, 3))
@@ -82,6 +85,7 @@ def test_multi_channel_consistency():
8285
raise ARTTestException("Multi-channel image consistency check failed") from e
8386

8487

88+
@pytest.mark.only_with_platform("pytorch")
8589
def test_one_pixel_effect_with_pytorchclassifier():
8690
try:
8791
import torch
@@ -92,11 +96,11 @@ def test_one_pixel_effect_with_pytorchclassifier():
9296
np.random.seed(0)
9397

9498
# Create a toy dataset: 2x2 grayscale images, 2 classes
95-
X = np.zeros((8, 1, 2, 2), dtype=np.float32)
99+
x = np.zeros((8, 1, 2, 2), dtype=np.float32)
96100
for i in range(4):
97-
X[i, 0, 0, 0] = i * 0.25 # class 0
101+
x[i, 0, 0, 0] = i * 0.25 # class 0
98102
for i in range(4, 8):
99-
X[i, 0, 0, 1] = (i - 4) * 0.25 # class 1
103+
x[i, 0, 0, 1] = (i - 4) * 0.25 # class 1
100104
y = np.array([0, 0, 0, 0, 1, 1, 1, 1])
101105

102106
model_clean = nn.Sequential(nn.Flatten(), nn.Linear(4, 2))
@@ -114,7 +118,7 @@ def test_one_pixel_effect_with_pytorchclassifier():
114118
acc_clean = np.mean(preds_clean.argmax(axis=1) == y)
115119

116120
ops_attack = OnePixelShortcutAttack()
117-
X_poison, y_poison = ops_attack.poison(X.copy(), y.copy())
121+
x_poison, y_poison = ops_attack.poison(X.copy(), y.copy())
118122

119123
model_poisoned = nn.Sequential(nn.Flatten(), nn.Linear(4, 2))
120124
classifier_poisoned = PyTorchClassifier(
@@ -125,24 +129,25 @@ def test_one_pixel_effect_with_pytorchclassifier():
125129
nb_classes=2,
126130
)
127131
classifier_poisoned.fit(
128-
X_poison,
132+
x_poison,
129133
y_poison,
130134
nb_epochs=10,
131135
batch_size=4,
132136
verbose=0,
133137
)
134-
preds_poisoned = classifier_poisoned.predict(X_poison)
138+
preds_poisoned = classifier_poisoned.predict(x_poison)
135139
acc_poisoned = np.mean(preds_poisoned.argmax(axis=1) == y_poison)
136140

137141
# Adjusted assertions for robustness
138-
assert acc_poisoned >= 1.0, f"Expected 100% poisoned accuracy, got {acc_poisoned:.3f}"
142+
assert acc_poisoned == 1.0, f"Expected 100% poisoned accuracy, got {acc_poisoned:.3f}"
139143
assert acc_clean < 0.95, f"Expected clean accuracy < 95%, got {acc_clean:.3f}"
140144

141145
except Exception as e:
142146
logger.warning("test_one_pixel_effect_with_pytorchclassifier failed: %s", e)
143147
raise ARTTestException("PyTorchClassifier integration with OPS attack failed") from e
144148

145149

150+
@pytest.mark.framework_agnostic
146151
def test_check_params_noop():
147152
try:
148153
attack = OnePixelShortcutAttack()
@@ -154,6 +159,7 @@ def test_check_params_noop():
154159
raise ARTTestException("Parameter check method failed unexpectedly") from e
155160

156161

162+
@pytest.mark.framework_agnostic
157163
def test_ambiguous_layout_nhwc():
158164
try:
159165
# Shape (N=1, H=3, W=2, C=3) - ambiguous, both 3 could be channels
@@ -173,6 +179,7 @@ def test_ambiguous_layout_nhwc():
173179
raise ARTTestException("Ambiguous NHWC layout handling failed") from e
174180

175181

182+
@pytest.mark.framework_agnostic
176183
def test_ambiguous_layout_nchw():
177184
try:
178185
# Shape (N=1, C=2, H=3, W=2) - ambiguous, no dim equals 1/3/4
@@ -192,6 +199,7 @@ def test_ambiguous_layout_nchw():
192199
raise ARTTestException("Ambiguous NCHW layout handling failed") from e
193200

194201

202+
@pytest.mark.framework_agnostic
195203
def test_unsupported_input_shape_raises_error():
196204
try:
197205
x = np.zeros((5, 5), dtype=np.float32) # 2D input (unsupported)
@@ -203,6 +211,7 @@ def test_unsupported_input_shape_raises_error():
203211
raise ARTTestException("ValueError not raised for unsupported input shape") from e
204212

205213

214+
@pytest.mark.framework_agnostic
206215
def test_one_hot_labels_preserve_format():
207216
try:
208217
# Two 2x2 grayscale images, with one-hot labels for classes 0 and 1
@@ -221,6 +230,7 @@ def test_one_hot_labels_preserve_format():
221230
raise ARTTestException("One-hot label handling failed") from e
222231

223232

233+
@pytest.mark.framework_agnostic
224234
def test_class_skipping_when_no_samples():
225235
try:
226236
# Small dataset: 1 image 2x2, class 0

0 commit comments

Comments
 (0)