Skip to content

Commit 35c0caa

Browse files
Fix OPS typing and tests; add coverage for edge branches
Signed-off-by: Nicholas Audric Adriel <[email protected]>
1 parent 15ff840 commit 35c0caa

File tree

2 files changed

+124
-17
lines changed

2 files changed

+124
-17
lines changed

art/attacks/poisoning/one_pixel_shortcut_attack.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,13 @@
1919
This module implements One Pixel Shortcut attacks on Deep Neural Networks.
2020
"""
2121

22+
from typing import Optional, Tuple
23+
2224
import numpy as np
2325

2426
from art.attacks.attack import PoisoningAttackBlackBox
2527

28+
2629
class OnePixelShortcutAttack(PoisoningAttackBlackBox):
2730
"""
2831
One-Pixel Shortcut (OPS) poisoning attack.
@@ -31,6 +34,7 @@ class OnePixelShortcutAttack(PoisoningAttackBlackBox):
3134
images. The found pixel coordinate and color are applied to all images of the class
3235
(labels remain unchanged). Reference: Wu et al. (ICLR 2023).
3336
"""
37+
3438
attack_params: list = [] # No external parameters for this attack
3539
_estimator_requirements: tuple = ()
3640

@@ -41,7 +45,7 @@ def _check_params(self):
4145
# No parameters to validate
4246
pass
4347

44-
def poison(self, x: np.ndarray, y: np.ndarray = None, **kwargs):
48+
def poison(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
4549
"""
4650
Generate an OPS-poisoned dataset from clean data.
4751
@@ -81,7 +85,7 @@ def poison(self, x: np.ndarray, y: np.ndarray = None, **kwargs):
8185
else:
8286
x_orig = np.transpose(x_array, (0, 2, 3, 1)).astype(np.float32)
8387
channels_first = True
84-
grayscale = (x_orig.shape[3] == 1)
88+
grayscale = x_orig.shape[3] == 1
8589
else:
8690
raise ValueError(f"Unsupported input tensor shape: {x_array.shape}")
8791

@@ -113,19 +117,16 @@ def poison(self, x: np.ndarray, y: np.ndarray = None, **kwargs):
113117
np.array([1.0], dtype=x_orig.dtype),
114118
]
115119
else:
116-
target_options = [
117-
np.array(bits, dtype=x_orig.dtype)
118-
for bits in np.ndindex(*(2,) * c)
119-
]
120+
target_options = [np.array(bits, dtype=x_orig.dtype) for bits in np.ndindex(*(2,) * c)]
120121
# Evaluate each candidate color
121122
for target_vec in target_options:
122123
# Compute per-image average difference from target for all pixels
123-
diffs = np.abs(imgs_c - target_vec) # shape (n_c, H, W, C)
124-
per_image_diff = diffs.mean(axis=3) # shape (n_c, H, W), mean diff per image at each pixel
124+
diffs = np.abs(imgs_c - target_vec) # shape (n_c, H, W, C)
125+
per_image_diff = diffs.mean(axis=3) # shape (n_c, H, W), mean diff per image at each pixel
125126
# Compute score = mean - var for each pixel position (vectorized over HxW)
126-
mean_diff_map = per_image_diff.mean(axis=0) # shape (H, W)
127-
var_diff_map = per_image_diff.var(axis=0) # shape (H, W)
128-
score_map = mean_diff_map - var_diff_map # shape (H, W)
127+
mean_diff_map = per_image_diff.mean(axis=0) # shape (H, W)
128+
var_diff_map = per_image_diff.var(axis=0) # shape (H, W)
129+
score_map = mean_diff_map - var_diff_map # shape (H, W)
129130
# Find the pixel with maximum score for this target
130131
max_idx_flat = np.argmax(score_map)
131132
max_score = score_map.ravel()[max_idx_flat]

tests/attacks/poison/test_one_pixel_shortcut_attack.py

Lines changed: 112 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,12 @@
2222
import pytest
2323

2424
from art.attacks.poisoning.one_pixel_shortcut_attack import OnePixelShortcutAttack
25+
from unittest.mock import patch
2526
from tests.utils import ARTTestException
2627

2728
logger = logging.getLogger(__name__)
2829

30+
2931
def test_one_pixel_per_image_and_label_preservation():
3032
try:
3133
x = np.zeros((4, 3, 3))
@@ -46,6 +48,7 @@ def test_one_pixel_per_image_and_label_preservation():
4648
logger.warning("test_one_pixel_per_image_and_label_preservation failed: %s", e)
4749
raise ARTTestException("Pixel change or label consistency check failed") from e
4850

51+
4952
def test_missing_labels_raises_error():
5053
try:
5154
x = np.zeros((3, 5, 5))
@@ -55,6 +58,7 @@ def test_missing_labels_raises_error():
5558
logger.warning("test_missing_labels_raises_error failed: %s", e)
5659
raise ARTTestException("Expected error not raised for missing labels") from e
5760

61+
5862
def test_multi_channel_consistency():
5963
try:
6064
x = np.zeros((2, 2, 2, 3))
@@ -77,6 +81,7 @@ def test_multi_channel_consistency():
7781
logger.warning("test_multi_channel_consistency failed: %s", e)
7882
raise ARTTestException("Multi-channel image consistency check failed") from e
7983

84+
8085
def test_one_pixel_effect_with_pytorchclassifier():
8186
try:
8287
import torch
@@ -130,13 +135,114 @@ def test_one_pixel_effect_with_pytorchclassifier():
130135
acc_poisoned = np.mean(preds_poisoned.argmax(axis=1) == y_poison)
131136

132137
# Adjusted assertions for robustness
133-
assert acc_poisoned >= 1.0, (
134-
f"Expected 100% poisoned accuracy, got {acc_poisoned:.3f}"
135-
)
138+
assert acc_poisoned >= 1.0, f"Expected 100% poisoned accuracy, got {acc_poisoned:.3f}"
136139
assert acc_clean < 0.95, f"Expected clean accuracy < 95%, got {acc_clean:.3f}"
137140

138141
except Exception as e:
139142
logger.warning("test_one_pixel_effect_with_pytorchclassifier failed: %s", e)
140-
raise ARTTestException(
141-
"PyTorchClassifier integration with OPS attack failed"
142-
) from e
143+
raise ARTTestException("PyTorchClassifier integration with OPS attack failed") from e
144+
145+
146+
def test_check_params_noop():
147+
try:
148+
attack = OnePixelShortcutAttack()
149+
# _check_params should do nothing (no error raised)
150+
result = attack._check_params()
151+
assert result is None
152+
except Exception as e:
153+
logger.warning("test_check_params_noop failed: %s", e)
154+
raise ARTTestException("Parameter check method failed unexpectedly") from e
155+
156+
157+
def test_ambiguous_layout_nhwc():
158+
try:
159+
# Shape (N=1, H=3, W=2, C=3) - ambiguous, both 3 could be channels
160+
x = np.zeros((1, 3, 2, 3), dtype=np.float32)
161+
y = np.array([0]) # single sample, class 0
162+
attack = OnePixelShortcutAttack()
163+
x_p, y_p = attack.poison(x.copy(), y.copy())
164+
# Output shape should match input shape
165+
assert x_p.shape == x.shape and x_p.dtype == x.dtype
166+
assert np.array_equal(y_p, y)
167+
# Verify exactly one pixel (position) was changed
168+
diff_any = np.any(x_p != x, axis=3) # collapse channel differences
169+
changes = np.sum(diff_any, axis=(1, 2))
170+
assert np.all(changes == 1)
171+
except Exception as e:
172+
logger.warning("test_ambiguous_layout_nhwc failed: %s", e)
173+
raise ARTTestException("Ambiguous NHWC layout handling failed") from e
174+
175+
176+
def test_ambiguous_layout_nchw():
177+
try:
178+
# Shape (N=1, C=2, H=3, W=2) - ambiguous, no dim equals 1/3/4
179+
x = np.zeros((1, 2, 3, 2), dtype=np.float32)
180+
y = np.array([0])
181+
attack = OnePixelShortcutAttack()
182+
x_p, y_p = attack.poison(x.copy(), y.copy())
183+
# Output shape unchanged
184+
assert x_p.shape == x.shape and x_p.dtype == x.dtype
185+
assert np.array_equal(y_p, y)
186+
# Exactly one pixel changed (channels-first input)
187+
diff_any = np.any(x_p != x, axis=1) # collapse channels axis
188+
changes = np.sum(diff_any, axis=(1, 2))
189+
assert np.all(changes == 1)
190+
except Exception as e:
191+
logger.warning("test_ambiguous_layout_nchw failed: %s", e)
192+
raise ARTTestException("Ambiguous NCHW layout handling failed") from e
193+
194+
195+
def test_unsupported_input_shape_raises_error():
196+
try:
197+
x = np.zeros((5, 5), dtype=np.float32) # 2D input (unsupported)
198+
y = np.array([0, 1, 2, 3, 4]) # Dummy labels (not actually used due to error)
199+
with pytest.raises(ValueError):
200+
OnePixelShortcutAttack().poison(x, y)
201+
except Exception as e:
202+
logger.warning("test_unsupported_input_shape_raises_error failed: %s", e)
203+
raise ARTTestException("ValueError not raised for unsupported input shape") from e
204+
205+
206+
def test_one_hot_labels_preserve_format():
207+
try:
208+
# Two 2x2 grayscale images, with one-hot labels for classes 0 and 1
209+
x = np.zeros((2, 2, 2), dtype=np.float32) # shape (N=2, H=2, W=2)
210+
y = np.array([[1, 0], [0, 1]], dtype=np.float32) # shape (2, 2) one-hot labels
211+
attack = OnePixelShortcutAttack()
212+
x_p, y_p = attack.poison(x.copy(), y.copy())
213+
# Labels should remain one-hot and unchanged
214+
assert y_p.shape == y.shape
215+
assert np.array_equal(y_p, y)
216+
# Exactly one pixel changed per image
217+
changes = np.sum(x_p != x, axis=(1, 2))
218+
assert np.all(changes == 1)
219+
except Exception as e:
220+
logger.warning("test_one_hot_labels_preserve_format failed: %s", e)
221+
raise ARTTestException("One-hot label handling failed") from e
222+
223+
224+
def test_class_skipping_when_no_samples():
225+
try:
226+
# Small dataset: 1 image 2x2, class 0
227+
x = np.zeros((1, 2, 2), dtype=np.float32)
228+
y = np.array([0])
229+
attack = OnePixelShortcutAttack()
230+
# Baseline output without any patch
231+
x_ref, y_ref = attack.poison(x.copy(), y.copy())
232+
# Monkey-patch np.unique in attack module to add a non-existent class (e.g., class 1)
233+
dummy_class = np.array([1], dtype=int)
234+
orig_unique = np.unique
235+
with patch(
236+
"art.attacks.poisoning.one_pixel_shortcut_attack.np.unique",
237+
new=lambda arr: np.concatenate([orig_unique(arr), dummy_class]),
238+
):
239+
x_p, y_p = attack.poison(x.copy(), y.copy())
240+
# Output with dummy class skip should match baseline output
241+
assert np.array_equal(x_p, x_ref)
242+
assert np.array_equal(y_p, y_ref)
243+
# And still exactly one pixel changed
244+
changes = np.sum(x_p != x, axis=(1, 2))
245+
assert np.all(changes == 1)
246+
except Exception as e:
247+
logger.warning("test_class_skipping_when_no_samples failed: %s", e)
248+
raise ARTTestException("Class-skipping branch handling failed") from e

0 commit comments

Comments
 (0)