
Commit fbed197

Add OnePixelShortcutAttack poisoning attack and unit tests
1 parent 0116171 commit fbed197

3 files changed: +291 -0 lines changed


art/attacks/poisoning/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -19,3 +19,4 @@
 from art.attacks.poisoning.hidden_trigger_backdoor.hidden_trigger_backdoor_pytorch import HiddenTriggerBackdoorPyTorch
 from art.attacks.poisoning.hidden_trigger_backdoor.hidden_trigger_backdoor_keras import HiddenTriggerBackdoorKeras
 from art.attacks.poisoning.sleeper_agent_attack import SleeperAgentAttack
+from art.attacks.poisoning.one_pixel_shortcut_attack import OnePixelShortcutAttack
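With this re-export in place, the attack can be imported from the package root like the other poisoning attacks, for example:

    from art.attacks.poisoning import OnePixelShortcutAttack

    attack = OnePixelShortcutAttack()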
art/attacks/poisoning/one_pixel_shortcut_attack.py

Lines changed: 148 additions & 0 deletions
@@ -0,0 +1,148 @@
# MIT License
#
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2025
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
This module implements the One-Pixel Shortcut (OPS) poisoning attack on Deep Neural Networks.
"""
from typing import Optional, Tuple

import numpy as np

from art.attacks.attack import PoisoningAttackBlackBox


class OnePixelShortcutAttack(PoisoningAttackBlackBox):
    """
    One-Pixel Shortcut (OPS) poisoning attack.

    For each class, the attack searches for a single pixel position and colour that act as a
    "shortcut" by maximizing a mean-minus-variance objective over that class's images. The
    selected coordinate and colour are then written into every image of the class; labels
    remain unchanged. Reference: Wu et al., "One-Pixel Shortcut: On the Learning Preference
    of Deep Neural Networks", ICLR 2023.
    """

    attack_params: list = []  # this attack exposes no external parameters
    _estimator_requirements: tuple = ()

    def __init__(self):
        super().__init__()

    def _check_params(self):
        # No parameters to validate.
        pass

    def poison(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
        """
        Generate an OPS-poisoned dataset from clean data.

        :param x: Clean input samples of shape (N, H, W), (N, H, W, C) or (N, C, H, W), with values in [0, 1].
        :param y: Corresponding labels, as class indices of shape (N,) or one-hot of shape (N, K).
                  Required, since the perturbation is computed per class.
        :return: Tuple (x_poisoned, y_poisoned) with one pixel modified per image and labels unchanged.
        """
        if y is None:
            raise ValueError("Labels y must be provided for the One-Pixel Shortcut attack.")
        # Labels are returned unchanged: OPS is a clean-label attack.
        y_poison = y.copy()

        # Normalize the input to an (N, H, W, C) float32 array and remember the original layout.
        x_array = np.asarray(x)
        added_channel_dim = False
        if x_array.ndim == 3:
            # (N, H, W): single-channel images without an explicit channel dimension.
            x_orig = x_array.reshape(x_array.shape + (1,)).astype(np.float32)
            channels_first = False
            added_channel_dim = True
        elif x_array.ndim == 4:
            # Distinguish NCHW from NHWC, assuming a channel count of 1, 3 or 4
            # (grayscale, RGB, RGBA).
            if x_array.shape[1] in (1, 3, 4) and x_array.shape[-1] not in (1, 3, 4):
                # (N, C, H, W)
                x_orig = np.transpose(x_array, (0, 2, 3, 1)).astype(np.float32)
                channels_first = True
            elif x_array.shape[-1] in (1, 3, 4) and x_array.shape[1] not in (1, 3, 4):
                # (N, H, W, C)
                x_orig = x_array.astype(np.float32)
                channels_first = False
            else:
                # Ambiguous case (e.g. tiny images where both axes look like channels):
                # default to channels-last if the last dimension matches a known channel count.
                if x_array.shape[-1] in (1, 3, 4):
                    x_orig = x_array.astype(np.float32)
                    channels_first = False
                else:
                    x_orig = np.transpose(x_array, (0, 2, 3, 1)).astype(np.float32)
                    channels_first = True
        else:
            raise ValueError(f"Unsupported input tensor shape: {x_array.shape}")

        # x_orig is now (N, H, W, C) in float32.
        n, h, w, c = x_orig.shape
        # Reduce labels to class indices.
        labels = np.asarray(y).copy()
        if labels.ndim > 1:
            labels = labels.argmax(axis=1)
        labels = labels.astype(int)

        x_poison = x_orig.copy()

        # Compute the optimal pixel (position and colour) for each class.
        for cls in np.unique(labels):
            idx = np.where(labels == cls)[0]
            if idx.size == 0:
                continue  # no samples for this class
            imgs_c = x_orig[idx]  # images of class `cls`, shape (n_c, H, W, C)
            best_score = -np.inf
            best_coord = None
            best_color = None
            # Candidate colours are the corners of the colour cube (0 or 1 in each channel).
            if c == 1:
                target_options = [
                    np.array([0.0], dtype=x_orig.dtype),
                    np.array([1.0], dtype=x_orig.dtype),
                ]
            else:
                target_options = [np.array(bits, dtype=x_orig.dtype) for bits in np.ndindex(*(2,) * c)]
            for target_vec in target_options:
                # Per-image mean absolute difference from the candidate colour, at every pixel.
                diffs = np.abs(imgs_c - target_vec)  # (n_c, H, W, C)
                per_image_diff = diffs.mean(axis=3)  # (n_c, H, W)
                # Score each pixel position by mean minus variance across the class.
                mean_diff_map = per_image_diff.mean(axis=0)  # (H, W)
                var_diff_map = per_image_diff.var(axis=0)  # (H, W)
                score_map = mean_diff_map - var_diff_map  # (H, W)
                # Track the best pixel over all candidate colours.
                max_idx_flat = np.argmax(score_map)
                max_score = score_map.ravel()[max_idx_flat]
                if max_score > best_score:
                    best_score = float(max_score)
                    best_coord = (max_idx_flat // w, max_idx_flat % w)  # flat index -> (i, j)
                    best_color = target_vec
            # Apply the winning pixel perturbation to all images of this class.
            if best_coord is not None:
                i_star, j_star = best_coord
                x_poison[idx, i_star, j_star, :] = best_color

        # Restore the original layout and dtype. The channel axis is dropped only if it
        # was added above, so 4-D single-channel inputs keep their shape.
        if channels_first:
            x_poison = np.transpose(x_poison, (0, 3, 1, 2))
        if added_channel_dim:
            x_poison = x_poison.reshape(n, h, w)
        x_poison = x_poison.astype(x_array.dtype)
        return x_poison, y_poison
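For intuition: for each class, poison() scores every pixel position (i, j) and candidate colour t (a corner of the colour cube) by mean_n d_n(i, j) - var_n d_n(i, j), where d_n(i, j) is the per-channel mean absolute distance of image n from t at that pixel, and then writes the argmax colour into the argmax position for every image of the class. A minimal usage sketch (not part of the commit; the random batch below is hypothetical), assuming NHWC data with values in [0, 1]:

    import numpy as np

    from art.attacks.poisoning import OnePixelShortcutAttack

    # Hypothetical toy batch: 10 RGB images of size 8x8 in NHWC layout, two classes.
    rng = np.random.default_rng(0)
    x = rng.random((10, 8, 8, 3)).astype(np.float32)
    y = np.array([0, 1] * 5)

    attack = OnePixelShortcutAttack()
    x_poisoned, y_poisoned = attack.poison(x, y)

    assert x_poisoned.shape == x.shape    # layout and dtype are preserved
    assert np.array_equal(y_poisoned, y)  # clean-label: targets are untouched
    # For generic float data, each image differs from the original in exactly one position.
    assert np.all(np.any(x_poisoned != x, axis=3).sum(axis=(1, 2)) == 1)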
Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@
# MIT License
#
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2025
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import logging

import numpy as np
import pytest

from art.attacks.poisoning.one_pixel_shortcut_attack import OnePixelShortcutAttack
from tests.utils import ARTTestException

logger = logging.getLogger(__name__)


def test_one_pixel_per_image_and_label_preservation():
    try:
        x = np.zeros((4, 3, 3))
        y = np.array([0, 0, 1, 1])
        attack = OnePixelShortcutAttack()
        x_p, y_p = attack.poison(x.copy(), y.copy())

        assert x_p.shape == x.shape
        assert np.array_equal(y_p, y)

        # Exactly one pixel changes in every image.
        changes = np.sum(x_p != x, axis=(1, 2))
        assert np.all(changes == 1)

        # Images of the same class share the same shortcut coordinate.
        coords = [tuple(np.argwhere(x_p[i] != x[i])[0]) for i in range(x.shape[0])]
        assert coords[0] == coords[1]
        assert coords[2] == coords[3]
    except Exception as e:
        logger.warning("test_one_pixel_per_image_and_label_preservation failed: %s", e)
        raise ARTTestException("Pixel change or label consistency check failed") from e


def test_missing_labels_raises_error():
    try:
        x = np.zeros((3, 5, 5))
        with pytest.raises(ValueError):
            OnePixelShortcutAttack().poison(x.copy(), None)
    except Exception as e:
        logger.warning("test_missing_labels_raises_error failed: %s", e)
        raise ARTTestException("Expected error not raised for missing labels") from e


def test_multi_channel_consistency():
    try:
        x = np.zeros((2, 2, 2, 3))
        y = np.array([0, 1])
        attack = OnePixelShortcutAttack()
        x_p, y_p = attack.poison(x.copy(), y.copy())

        assert x_p.shape == x.shape
        assert np.array_equal(y_p, y)

        # Exactly one spatial position changes per image, counted across all channels.
        diff_any = np.any(x_p != x, axis=3)
        changes = np.sum(diff_any, axis=(1, 2))
        assert np.all(changes == 1)

        coords0 = np.argwhere(diff_any[0])
        coords1 = np.argwhere(diff_any[1])
        assert coords0.shape[0] == 1
        assert coords1.shape[0] == 1
    except Exception as e:
        logger.warning("test_multi_channel_consistency failed: %s", e)
        raise ARTTestException("Multi-channel image consistency check failed") from e


def test_one_pixel_effect_with_pytorchclassifier():
    try:
        import torch
        import torch.nn as nn
        from art.estimators.classification import PyTorchClassifier

        torch.manual_seed(0)
        np.random.seed(0)

        # Toy dataset: 2x2 grayscale images, 2 classes.
        X = np.zeros((8, 1, 2, 2), dtype=np.float32)
        for i in range(4):
            X[i, 0, 0, 0] = i * 0.25  # class 0
        for i in range(4, 8):
            X[i, 0, 0, 1] = (i - 4) * 0.25  # class 1
        y = np.array([0, 0, 0, 0, 1, 1, 1, 1])

        model_clean = nn.Sequential(nn.Flatten(), nn.Linear(4, 2))
        loss_fn = nn.CrossEntropyLoss()

        classifier_clean = PyTorchClassifier(
            model=model_clean,
            loss=loss_fn,
            optimizer=torch.optim.SGD(model_clean.parameters(), lr=0.1),
            input_shape=(1, 2, 2),
            nb_classes=2,
        )
        classifier_clean.fit(X, y, nb_epochs=10, batch_size=4, verbose=0)
        preds_clean = classifier_clean.predict(X)
        acc_clean = np.mean(preds_clean.argmax(axis=1) == y)

        ops_attack = OnePixelShortcutAttack()
        X_poison, y_poison = ops_attack.poison(X.copy(), y.copy())

        model_poisoned = nn.Sequential(nn.Flatten(), nn.Linear(4, 2))
        classifier_poisoned = PyTorchClassifier(
            model=model_poisoned,
            loss=loss_fn,
            optimizer=torch.optim.SGD(model_poisoned.parameters(), lr=0.1),
            input_shape=(1, 2, 2),
            nb_classes=2,
        )
        classifier_poisoned.fit(X_poison, y_poison, nb_epochs=10, batch_size=4, verbose=0)
        preds_poisoned = classifier_poisoned.predict(X_poison)
        acc_poisoned = np.mean(preds_poisoned.argmax(axis=1) == y_poison)

        # The shortcut pixel makes the poisoned set trivially separable, so the poisoned
        # model should fit it perfectly, while the clean model should stay below perfect
        # accuracy within the same small training budget.
        assert acc_poisoned >= 1.0, f"Expected 100% poisoned accuracy, got {acc_poisoned:.3f}"
        assert acc_clean < 0.95, f"Expected clean accuracy < 95%, got {acc_clean:.3f}"
    except Exception as e:
        logger.warning("test_one_pixel_effect_with_pytorchclassifier failed: %s", e)
        raise ARTTestException("PyTorchClassifier integration with OPS attack failed") from e
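As a cross-check on the objective these tests exercise, the class-wise search reduces, for single-channel images, to the brute-force loop below (a sketch, not part of the commit; ops_best_pixel is a hypothetical helper name):

    import numpy as np

    def ops_best_pixel(imgs: np.ndarray):
        """Brute-force the OPS objective for one class of (n_c, H, W, 1) images in [0, 1]."""
        n_c, h, w, _ = imgs.shape
        best_coord, best_color, best_score = None, None, -np.inf
        for target in (0.0, 1.0):  # colour-cube corners for a single channel
            d = np.abs(imgs[..., 0] - target)       # (n_c, H, W) distance of each image to the target
            score = d.mean(axis=0) - d.var(axis=0)  # mean-minus-variance at each pixel position
            i, j = np.unravel_index(np.argmax(score), (h, w))
            if score[i, j] > best_score:
                best_coord, best_color, best_score = (i, j), target, float(score[i, j])
        return best_coord, best_color, best_score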
