22
22
import pytest
23
23
24
24
from art .attacks .poisoning .one_pixel_shortcut_attack import OnePixelShortcutAttack
25
+ from unittest .mock import patch
25
26
from tests .utils import ARTTestException
26
27
27
28
logger = logging .getLogger (__name__ )
28
29
30
+
29
31
def test_one_pixel_per_image_and_label_preservation ():
30
32
try :
31
33
x = np .zeros ((4 , 3 , 3 ))
@@ -46,6 +48,7 @@ def test_one_pixel_per_image_and_label_preservation():
46
48
logger .warning ("test_one_pixel_per_image_and_label_preservation failed: %s" , e )
47
49
raise ARTTestException ("Pixel change or label consistency check failed" ) from e
48
50
51
+
49
52
def test_missing_labels_raises_error ():
50
53
try :
51
54
x = np .zeros ((3 , 5 , 5 ))
@@ -55,6 +58,7 @@ def test_missing_labels_raises_error():
55
58
logger .warning ("test_missing_labels_raises_error failed: %s" , e )
56
59
raise ARTTestException ("Expected error not raised for missing labels" ) from e
57
60
61
+
58
62
def test_multi_channel_consistency ():
59
63
try :
60
64
x = np .zeros ((2 , 2 , 2 , 3 ))
@@ -77,6 +81,7 @@ def test_multi_channel_consistency():
77
81
logger .warning ("test_multi_channel_consistency failed: %s" , e )
78
82
raise ARTTestException ("Multi-channel image consistency check failed" ) from e
79
83
84
+
80
85
def test_one_pixel_effect_with_pytorchclassifier ():
81
86
try :
82
87
import torch
@@ -130,13 +135,114 @@ def test_one_pixel_effect_with_pytorchclassifier():
130
135
acc_poisoned = np .mean (preds_poisoned .argmax (axis = 1 ) == y_poison )
131
136
132
137
# Adjusted assertions for robustness
133
- assert acc_poisoned >= 1.0 , (
134
- f"Expected 100% poisoned accuracy, got { acc_poisoned :.3f} "
135
- )
138
+ assert acc_poisoned >= 1.0 , f"Expected 100% poisoned accuracy, got { acc_poisoned :.3f} "
136
139
assert acc_clean < 0.95 , f"Expected clean accuracy < 95%, got { acc_clean :.3f} "
137
140
138
141
except Exception as e :
139
142
logger .warning ("test_one_pixel_effect_with_pytorchclassifier failed: %s" , e )
140
- raise ARTTestException (
141
- "PyTorchClassifier integration with OPS attack failed"
142
- ) from e
143
+ raise ARTTestException ("PyTorchClassifier integration with OPS attack failed" ) from e
144
+
145
+
146
+ def test_check_params_noop ():
147
+ try :
148
+ attack = OnePixelShortcutAttack ()
149
+ # _check_params should do nothing (no error raised)
150
+ result = attack ._check_params ()
151
+ assert result is None
152
+ except Exception as e :
153
+ logger .warning ("test_check_params_noop failed: %s" , e )
154
+ raise ARTTestException ("Parameter check method failed unexpectedly" ) from e
155
+
156
+
157
+ def test_ambiguous_layout_nhwc ():
158
+ try :
159
+ # Shape (N=1, H=3, W=2, C=3) - ambiguous, both 3 could be channels
160
+ x = np .zeros ((1 , 3 , 2 , 3 ), dtype = np .float32 )
161
+ y = np .array ([0 ]) # single sample, class 0
162
+ attack = OnePixelShortcutAttack ()
163
+ x_p , y_p = attack .poison (x .copy (), y .copy ())
164
+ # Output shape should match input shape
165
+ assert x_p .shape == x .shape and x_p .dtype == x .dtype
166
+ assert np .array_equal (y_p , y )
167
+ # Verify exactly one pixel (position) was changed
168
+ diff_any = np .any (x_p != x , axis = 3 ) # collapse channel differences
169
+ changes = np .sum (diff_any , axis = (1 , 2 ))
170
+ assert np .all (changes == 1 )
171
+ except Exception as e :
172
+ logger .warning ("test_ambiguous_layout_nhwc failed: %s" , e )
173
+ raise ARTTestException ("Ambiguous NHWC layout handling failed" ) from e
174
+
175
+
176
+ def test_ambiguous_layout_nchw ():
177
+ try :
178
+ # Shape (N=1, C=2, H=3, W=2) - ambiguous, no dim equals 1/3/4
179
+ x = np .zeros ((1 , 2 , 3 , 2 ), dtype = np .float32 )
180
+ y = np .array ([0 ])
181
+ attack = OnePixelShortcutAttack ()
182
+ x_p , y_p = attack .poison (x .copy (), y .copy ())
183
+ # Output shape unchanged
184
+ assert x_p .shape == x .shape and x_p .dtype == x .dtype
185
+ assert np .array_equal (y_p , y )
186
+ # Exactly one pixel changed (channels-first input)
187
+ diff_any = np .any (x_p != x , axis = 1 ) # collapse channels axis
188
+ changes = np .sum (diff_any , axis = (1 , 2 ))
189
+ assert np .all (changes == 1 )
190
+ except Exception as e :
191
+ logger .warning ("test_ambiguous_layout_nchw failed: %s" , e )
192
+ raise ARTTestException ("Ambiguous NCHW layout handling failed" ) from e
193
+
194
+
195
+ def test_unsupported_input_shape_raises_error ():
196
+ try :
197
+ x = np .zeros ((5 , 5 ), dtype = np .float32 ) # 2D input (unsupported)
198
+ y = np .array ([0 , 1 , 2 , 3 , 4 ]) # Dummy labels (not actually used due to error)
199
+ with pytest .raises (ValueError ):
200
+ OnePixelShortcutAttack ().poison (x , y )
201
+ except Exception as e :
202
+ logger .warning ("test_unsupported_input_shape_raises_error failed: %s" , e )
203
+ raise ARTTestException ("ValueError not raised for unsupported input shape" ) from e
204
+
205
+
206
+ def test_one_hot_labels_preserve_format ():
207
+ try :
208
+ # Two 2x2 grayscale images, with one-hot labels for classes 0 and 1
209
+ x = np .zeros ((2 , 2 , 2 ), dtype = np .float32 ) # shape (N=2, H=2, W=2)
210
+ y = np .array ([[1 , 0 ], [0 , 1 ]], dtype = np .float32 ) # shape (2, 2) one-hot labels
211
+ attack = OnePixelShortcutAttack ()
212
+ x_p , y_p = attack .poison (x .copy (), y .copy ())
213
+ # Labels should remain one-hot and unchanged
214
+ assert y_p .shape == y .shape
215
+ assert np .array_equal (y_p , y )
216
+ # Exactly one pixel changed per image
217
+ changes = np .sum (x_p != x , axis = (1 , 2 ))
218
+ assert np .all (changes == 1 )
219
+ except Exception as e :
220
+ logger .warning ("test_one_hot_labels_preserve_format failed: %s" , e )
221
+ raise ARTTestException ("One-hot label handling failed" ) from e
222
+
223
+
224
+ def test_class_skipping_when_no_samples ():
225
+ try :
226
+ # Small dataset: 1 image 2x2, class 0
227
+ x = np .zeros ((1 , 2 , 2 ), dtype = np .float32 )
228
+ y = np .array ([0 ])
229
+ attack = OnePixelShortcutAttack ()
230
+ # Baseline output without any patch
231
+ x_ref , y_ref = attack .poison (x .copy (), y .copy ())
232
+ # Monkey-patch np.unique in attack module to add a non-existent class (e.g., class 1)
233
+ dummy_class = np .array ([1 ], dtype = int )
234
+ orig_unique = np .unique
235
+ with patch (
236
+ "art.attacks.poisoning.one_pixel_shortcut_attack.np.unique" ,
237
+ new = lambda arr : np .concatenate ([orig_unique (arr ), dummy_class ]),
238
+ ):
239
+ x_p , y_p = attack .poison (x .copy (), y .copy ())
240
+ # Output with dummy class skip should match baseline output
241
+ assert np .array_equal (x_p , x_ref )
242
+ assert np .array_equal (y_p , y_ref )
243
+ # And still exactly one pixel changed
244
+ changes = np .sum (x_p != x , axis = (1 , 2 ))
245
+ assert np .all (changes == 1 )
246
+ except Exception as e :
247
+ logger .warning ("test_class_skipping_when_no_samples failed: %s" , e )
248
+ raise ARTTestException ("Class-skipping branch handling failed" ) from e
0 commit comments