diff --git a/art/attacks/poisoning/__init__.py b/art/attacks/poisoning/__init__.py
index fa62ad125a..cf3c349487 100644
--- a/art/attacks/poisoning/__init__.py
+++ b/art/attacks/poisoning/__init__.py
@@ -19,3 +19,4 @@
 from art.attacks.poisoning.hidden_trigger_backdoor.hidden_trigger_backdoor_pytorch import HiddenTriggerBackdoorPyTorch
 from art.attacks.poisoning.hidden_trigger_backdoor.hidden_trigger_backdoor_keras import HiddenTriggerBackdoorKeras
 from art.attacks.poisoning.sleeper_agent_attack import SleeperAgentAttack
+from art.attacks.poisoning.one_pixel_shortcut_attack import OnePixelShortcutAttack
diff --git a/art/attacks/poisoning/one_pixel_shortcut_attack.py b/art/attacks/poisoning/one_pixel_shortcut_attack.py
new file mode 100644
index 0000000000..52f2616b1e
--- /dev/null
+++ b/art/attacks/poisoning/one_pixel_shortcut_attack.py
@@ -0,0 +1,149 @@
+# MIT License
+#
+# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2025
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
+# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
+# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
+# persons to whom the Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
+# Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
+# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+"""
+This module implements the One-Pixel Shortcut (OPS) poisoning attack on Deep Neural Networks.
+"""
+
+from typing import Optional, Tuple
+
+import numpy as np
+
+from art.attacks.attack import PoisoningAttackBlackBox
+
+
+class OnePixelShortcutAttack(PoisoningAttackBlackBox):
+    """
+    One-Pixel Shortcut (OPS) poisoning attack.
+
+    For each class, this attack finds a single pixel position (and channel value) that acts as a "shortcut" by
+    maximizing a mean-minus-variance objective over that class's images. The resulting pixel coordinate and color
+    are applied to all images of the class; labels remain unchanged.
+
+    Reference: Wu et al. (ICLR 2023).
+    """
+
+    attack_params: list = []  # No external parameters for this attack
+    _estimator_requirements: tuple = ()
+
+    def __init__(self):
+        super().__init__()
+
+    def _check_params(self):
+        # No parameters to validate
+        pass
+
+    def poison(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
+        """
+        Generate an OPS-poisoned dataset from clean data.
+
+        :param x: Clean input samples, as a NumPy array of shape (N, H, W, C), (N, C, H, W) or (N, H, W), with
+                  values in [0, 1].
+        :param y: Corresponding labels, of shape (N,) or one-hot encoded with shape (N, K). Required, because the
+                  perturbation is computed per class.
+        :return: Tuple (x_poisoned, y_poisoned) with one pixel modified per image.
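+
+        A minimal usage sketch (``x_train`` and ``y_train`` are illustrative names for clean data and labels)::
+
+            attack = OnePixelShortcutAttack()
+            x_poisoned, y_poisoned = attack.poison(x_train, y_train)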
+ """ + if y is None: + raise ValueError("Labels y must be provided for the One-Pixel Shortcut attack.") + # Copy labels to return (labels are not changed by poisoning) + y_poison = y.copy() + + # Convert inputs to numpy array (if not already) and determine channel format + x_array = np.array(x, copy=False) + if x_array.ndim == 3: + # Input shape (N, H, W) - single-channel images without explicit channel dim + x_orig = x_array.reshape((x_array.shape[0], x_array.shape[1], x_array.shape[2], 1)).astype(np.float32) + channels_first = False + grayscale = True + elif x_array.ndim == 4: + # Determine if format is NCHW or NHWC by examining dimensions + # Assume channel count is 1, 3, or 4 for common cases (grayscale, RGB, RGBA) + if x_array.shape[1] in (1, 3, 4) and x_array.shape[-1] not in (1, 3, 4): + # Likely (N, C, H, W) format + x_orig = np.transpose(x_array, (0, 2, 3, 1)).astype(np.float32) + channels_first = True + elif x_array.shape[-1] in (1, 3, 4) and x_array.shape[1] not in (1, 3, 4): + # Likely (N, H, W, C) format + x_orig = x_array.astype(np.float32) + channels_first = False + else: + # Ambiguous case: if both middle and last dims could be channels (e.g. tiny images) + # Default to treating last dimension as channels if it matches a known channel count + if x_array.shape[-1] in (1, 3, 4): + x_orig = x_array.astype(np.float32) + channels_first = False + else: + x_orig = np.transpose(x_array, (0, 2, 3, 1)).astype(np.float32) + channels_first = True + grayscale = x_orig.shape[3] == 1 + else: + raise ValueError(f"Unsupported input tensor shape: {x_array.shape}") + + # x_orig is now (N, H, W, C) in float32 + n, h, w, c = x_orig.shape + # Prepare class index labels + labels = y.copy() + if labels.ndim > 1: + labels = labels.argmax(axis=1) + labels = labels.astype(int) + + # Initialize output poisoned data array + x_poison = x_orig.copy() + + # Compute optimal pixel for each class + classes = np.unique(labels) + for cls in classes: + idx = np.where(labels == cls)[0] + if idx.size == 0: + continue # skip if no samples for this class + imgs_c = x_orig[idx] # subset of images of class `cls`, shape (n_c, H, W, C) + best_score = -np.inf + best_coord = None + best_color = None + # Determine target color options: extremes (0 or 1 in each channel) + if c == 1: + target_options = [ + np.array([0.0], dtype=x_orig.dtype), + np.array([1.0], dtype=x_orig.dtype), + ] + else: + target_options = [np.array(bits, dtype=x_orig.dtype) for bits in np.ndindex(*(2,) * c)] + # Evaluate each candidate color + for target_vec in target_options: + # Compute per-image average difference from target for all pixels + diffs = np.abs(imgs_c - target_vec) # shape (n_c, H, W, C) + per_image_diff = diffs.mean(axis=3) # shape (n_c, H, W), mean diff per image at each pixel + # Compute score = mean - var for each pixel position (vectorized over HxW) + mean_diff_map = per_image_diff.mean(axis=0) # shape (H, W) + var_diff_map = per_image_diff.var(axis=0) # shape (H, W) + score_map = mean_diff_map - var_diff_map # shape (H, W) + # Find the pixel with maximum score for this target + max_idx_flat = np.argmax(score_map) + max_score = score_map.ravel()[max_idx_flat] + if max_score > best_score: + best_score = float(max_score) + # Convert flat index to 2D coordinates (i, j) + best_coord = (max_idx_flat // w, max_idx_flat % w) + best_color = target_vec + # Apply the best pixel perturbation to all images of this class + if best_coord is not None: + i_star, j_star = best_coord + x_poison[idx, i_star, j_star, :] = best_color + + # Restore 
+                # Find the pixel with the maximum score for this target color
+                max_idx_flat = np.argmax(score_map)
+                max_score = score_map.ravel()[max_idx_flat]
+                if max_score > best_score:
+                    best_score = float(max_score)
+                    # Convert the flat index to 2D coordinates (i, j)
+                    best_coord = (max_idx_flat // w, max_idx_flat % w)
+                    best_color = target_vec
+            # Apply the best pixel perturbation to all images of this class
+            if best_coord is not None:
+                i_star, j_star = best_coord
+                x_poison[idx, i_star, j_star, :] = best_color
+
+        # Restore the original data format and dtype
+        if channels_first:
+            x_poison = np.transpose(x_poison, (0, 3, 1, 2))
+        if squeeze_channel:
+            # Drop the channel axis only if the input had none (3D input of shape (N, H, W))
+            x_poison = x_poison.reshape(n, h, w)
+        x_poison = x_poison.astype(x_array.dtype)
+        return x_poison, y_poison
diff --git a/tests/attacks/poison/test_one_pixel_shortcut_attack.py b/tests/attacks/poison/test_one_pixel_shortcut_attack.py
new file mode 100644
index 0000000000..21aabae013
--- /dev/null
+++ b/tests/attacks/poison/test_one_pixel_shortcut_attack.py
@@ -0,0 +1,248 @@
+# MIT License
+#
+# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2025
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
+# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
+# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
+# persons to whom the Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
+# Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
+# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import logging
+from unittest.mock import patch
+
+import numpy as np
+import pytest
+
+from art.attacks.poisoning.one_pixel_shortcut_attack import OnePixelShortcutAttack
+from tests.utils import ARTTestException
+
+logger = logging.getLogger(__name__)
+
+
+def test_one_pixel_per_image_and_label_preservation():
+    try:
+        x = np.zeros((4, 3, 3))
+        y = np.array([0, 0, 1, 1])
+        attack = OnePixelShortcutAttack()
+        x_p, y_p = attack.poison(x.copy(), y.copy())
+
+        assert x_p.shape == x.shape
+        assert np.array_equal(y_p, y)
+
+        changes = np.sum(x_p != x, axis=(1, 2))
+        assert np.all(changes == 1)
+
+        coords = [tuple(np.argwhere(x_p[i] != x[i])[0]) for i in range(x.shape[0])]
+        assert coords[0] == coords[1]
+        assert coords[2] == coords[3]
+    except Exception as e:
+        logger.warning("test_one_pixel_per_image_and_label_preservation failed: %s", e)
+        raise ARTTestException("Pixel change or label consistency check failed") from e
+
+
+def test_missing_labels_raises_error():
+    try:
+        x = np.zeros((3, 5, 5))
+        with pytest.raises(ValueError):
+            OnePixelShortcutAttack().poison(x.copy(), None)
+    except Exception as e:
+        logger.warning("test_missing_labels_raises_error failed: %s", e)
+        raise ARTTestException("Expected error not raised for missing labels") from e
+
+
+def test_multi_channel_consistency():
+    try:
+        x = np.zeros((2, 2, 2, 3))
+        y = np.array([0, 1])
+        attack = OnePixelShortcutAttack()
+        x_p, y_p = attack.poison(x.copy(), y.copy())
+
+        assert x_p.shape == x.shape
+        assert np.array_equal(y_p, y)
+
+        diff_any = np.any(x_p != x, axis=3)
+        changes = np.sum(diff_any, axis=(1, 2))
+        assert np.all(changes == 1)
+
+        coords0 = np.argwhere(diff_any[0])
+        coords1 = np.argwhere(diff_any[1])
+        assert coords0.shape[0] == 1
+        assert coords1.shape[0] == 1
+    except Exception as e:
+        logger.warning("test_multi_channel_consistency failed: %s", e)
+        raise ARTTestException("Multi-channel image consistency check failed") from e
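+
+
+# The integration test below exercises the end-to-end shortcut effect on a toy dataset: the clean
+# toy data is not linearly separable by construction (samples 0 and 4 are identical all-zero images
+# with different labels), while the OPS-poisoned version gives each class a distinctive pixel that
+# a linear model can latch onto.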
+def test_one_pixel_effect_with_pytorchclassifier():
+    try:
+        import torch
+        import torch.nn as nn
+        from art.estimators.classification import PyTorchClassifier
+
+        torch.manual_seed(0)
+        np.random.seed(0)
+
+        # Create a toy dataset: 2x2 grayscale images, 2 classes
+        X = np.zeros((8, 1, 2, 2), dtype=np.float32)
+        for i in range(4):
+            X[i, 0, 0, 0] = i * 0.25  # class 0
+        for i in range(4, 8):
+            X[i, 0, 0, 1] = (i - 4) * 0.25  # class 1
+        y = np.array([0, 0, 0, 0, 1, 1, 1, 1])
+
+        model_clean = nn.Sequential(nn.Flatten(), nn.Linear(4, 2))
+        loss_fn = nn.CrossEntropyLoss()
+
+        classifier_clean = PyTorchClassifier(
+            model=model_clean,
+            loss=loss_fn,
+            optimizer=torch.optim.SGD(model_clean.parameters(), lr=0.1),
+            input_shape=(1, 2, 2),
+            nb_classes=2,
+        )
+        classifier_clean.fit(X, y, nb_epochs=10, batch_size=4, verbose=0)
+        preds_clean = classifier_clean.predict(X)
+        acc_clean = np.mean(preds_clean.argmax(axis=1) == y)
+
+        ops_attack = OnePixelShortcutAttack()
+        X_poison, y_poison = ops_attack.poison(X.copy(), y.copy())
+
+        model_poisoned = nn.Sequential(nn.Flatten(), nn.Linear(4, 2))
+        classifier_poisoned = PyTorchClassifier(
+            model=model_poisoned,
+            loss=loss_fn,
+            optimizer=torch.optim.SGD(model_poisoned.parameters(), lr=0.1),
+            input_shape=(1, 2, 2),
+            nb_classes=2,
+        )
+        classifier_poisoned.fit(
+            X_poison,
+            y_poison,
+            nb_epochs=10,
+            batch_size=4,
+            verbose=0,
+        )
+        preds_poisoned = classifier_poisoned.predict(X_poison)
+        acc_poisoned = np.mean(preds_poisoned.argmax(axis=1) == y_poison)
+
+        # The poisoned model can key on the class-wise shortcut pixel, so it should fit the poisoned
+        # data perfectly; the clean data cannot be fit perfectly (see the note above), so clean
+        # accuracy must stay clearly below 100%.
+        assert acc_poisoned == 1.0, f"Expected 100% poisoned accuracy, got {acc_poisoned:.3f}"
+        assert acc_clean < 0.95, f"Expected clean accuracy < 95%, got {acc_clean:.3f}"
+
+    except Exception as e:
+        logger.warning("test_one_pixel_effect_with_pytorchclassifier failed: %s", e)
+        raise ARTTestException("PyTorchClassifier integration with OPS attack failed") from e
+
+
+def test_check_params_noop():
+    try:
+        attack = OnePixelShortcutAttack()
+        # _check_params should do nothing (no error raised)
+        result = attack._check_params()
+        assert result is None
+    except Exception as e:
+        logger.warning("test_check_params_noop failed: %s", e)
+        raise ARTTestException("Parameter check method failed unexpectedly") from e
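+
+
+# The two tests below exercise the layout-inference heuristic in `poison`: when the channel position
+# is ambiguous, or no dimension matches a known channel count (1, 3 or 4), the attack treats the
+# last dimension as channels if it matches a known count and falls back to NCHW otherwise.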
+def test_ambiguous_layout_nhwc():
+    try:
+        # Shape (N=1, H=3, W=2, C=3) - ambiguous, because both dimensions of size 3 could be channels
+        x = np.zeros((1, 3, 2, 3), dtype=np.float32)
+        y = np.array([0])  # single sample, class 0
+        attack = OnePixelShortcutAttack()
+        x_p, y_p = attack.poison(x.copy(), y.copy())
+        # Output shape and dtype should match the input
+        assert x_p.shape == x.shape and x_p.dtype == x.dtype
+        assert np.array_equal(y_p, y)
+        # Verify that exactly one pixel (position) was changed
+        diff_any = np.any(x_p != x, axis=3)  # collapse channel differences
+        changes = np.sum(diff_any, axis=(1, 2))
+        assert np.all(changes == 1)
+    except Exception as e:
+        logger.warning("test_ambiguous_layout_nhwc failed: %s", e)
+        raise ARTTestException("Ambiguous NHWC layout handling failed") from e
+
+
+def test_ambiguous_layout_nchw():
+    try:
+        # Shape (N=1, C=2, H=3, W=2) - ambiguous, because no dimension equals 1, 3 or 4
+        x = np.zeros((1, 2, 3, 2), dtype=np.float32)
+        y = np.array([0])
+        attack = OnePixelShortcutAttack()
+        x_p, y_p = attack.poison(x.copy(), y.copy())
+        # Output shape and dtype unchanged
+        assert x_p.shape == x.shape and x_p.dtype == x.dtype
+        assert np.array_equal(y_p, y)
+        # Exactly one pixel changed (channels-first input)
+        diff_any = np.any(x_p != x, axis=1)  # collapse the channel axis
+        changes = np.sum(diff_any, axis=(1, 2))
+        assert np.all(changes == 1)
+    except Exception as e:
+        logger.warning("test_ambiguous_layout_nchw failed: %s", e)
+        raise ARTTestException("Ambiguous NCHW layout handling failed") from e
+
+
+def test_unsupported_input_shape_raises_error():
+    try:
+        x = np.zeros((5, 5), dtype=np.float32)  # 2D input (unsupported)
+        y = np.array([0, 1, 2, 3, 4])  # dummy labels (never used, because the error is raised first)
+        with pytest.raises(ValueError):
+            OnePixelShortcutAttack().poison(x, y)
+    except Exception as e:
+        logger.warning("test_unsupported_input_shape_raises_error failed: %s", e)
+        raise ARTTestException("ValueError not raised for unsupported input shape") from e
+
+
+def test_one_hot_labels_preserve_format():
+    try:
+        # Two 2x2 grayscale images, with one-hot labels for classes 0 and 1
+        x = np.zeros((2, 2, 2), dtype=np.float32)  # shape (N=2, H=2, W=2)
+        y = np.array([[1, 0], [0, 1]], dtype=np.float32)  # shape (2, 2) one-hot labels
+        attack = OnePixelShortcutAttack()
+        x_p, y_p = attack.poison(x.copy(), y.copy())
+        # Labels should remain one-hot and unchanged
+        assert y_p.shape == y.shape
+        assert np.array_equal(y_p, y)
+        # Exactly one pixel changed per image
+        changes = np.sum(x_p != x, axis=(1, 2))
+        assert np.all(changes == 1)
+    except Exception as e:
+        logger.warning("test_one_hot_labels_preserve_format failed: %s", e)
+        raise ARTTestException("One-hot label handling failed") from e
+
+
+def test_class_skipping_when_no_samples():
+    try:
+        # Small dataset: one 2x2 image of class 0
+        x = np.zeros((1, 2, 2), dtype=np.float32)
+        y = np.array([0])
+        attack = OnePixelShortcutAttack()
+        # Baseline output without any patching
+        x_ref, y_ref = attack.poison(x.copy(), y.copy())
+        # Monkey-patch np.unique in the attack module to report a non-existent class (class 1)
+        dummy_class = np.array([1], dtype=int)
+        orig_unique = np.unique
+        with patch(
+            "art.attacks.poisoning.one_pixel_shortcut_attack.np.unique",
+            new=lambda arr: np.concatenate([orig_unique(arr), dummy_class]),
+        ):
+            x_p, y_p = attack.poison(x.copy(), y.copy())
+        # The output with the dummy-class skip should match the baseline output
+        assert np.array_equal(x_p, x_ref)
+        assert np.array_equal(y_p, y_ref)
+        # And exactly one pixel is still changed
+        changes = np.sum(x_p != x, axis=(1, 2))
+        assert np.all(changes == 1)
+    except Exception as e:
+        logger.warning("test_class_skipping_when_no_samples failed: %s", e)
+        raise ARTTestException("Class-skipping branch handling failed") from e
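
A minimal end-to-end usage sketch of the new attack (outside this diff; `x_train` and `y_train` are
illustrative stand-ins for a clean NHWC dataset in [0, 1]):

    import numpy as np
    from art.attacks.poisoning import OnePixelShortcutAttack

    # Toy stand-in for a real dataset: 100 RGB images of size 32x32, 10 classes
    x_train = np.random.rand(100, 32, 32, 3).astype(np.float32)
    y_train = np.random.randint(0, 10, size=100)

    attack = OnePixelShortcutAttack()
    x_poisoned, y_poisoned = attack.poison(x_train, y_train)

    # Shapes and labels are unchanged; each image has exactly one modified pixel
    assert x_poisoned.shape == x_train.shape
    assert np.array_equal(y_poisoned, y_train)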