Skip to content

Commit f89ee1b

Browse files
authored
Merge pull request #2539 from Trusted-AI/development_patch_mask
Fix bug in random sampling of patch locations in masks for adversarial patch attacks
2 parents cf11263 + 20f8e27 commit f89ee1b

File tree

8 files changed

+31
-29
lines changed

8 files changed

+31
-29
lines changed

.github/workflows/ci-pytorch-object-detectors.yml

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,6 @@ jobs:
4141
python -m pip install --upgrade pip setuptools wheel
4242
pip3 install -q -r requirements_test.txt
4343
pip list
44-
- name: Pre-install torch
45-
run: |
46-
pip install torch==1.12.1+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
47-
pip install torchvision==0.13.1+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
48-
pip install torchaudio==0.12.1+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
4944
- name: Run Test Action - test_pytorch_object_detector
5045
run: pytest --cov-report=xml --cov=art --cov-append -q -vv tests/estimators/object_detection/test_pytorch_object_detector.py --framework=pytorch --durations=0
5146
- name: Run Test Action - test_pytorch_faster_rcnn

art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -381,23 +381,23 @@ def _random_overlay(
381381
else:
382382
mask_2d = mask[i_sample, :, :]
383383

384-
edge_x_0 = int(im_scale * padded_patch.shape[self.i_w + 1]) // 2
385-
edge_x_1 = int(im_scale * padded_patch.shape[self.i_w + 1]) - edge_x_0
386-
edge_y_0 = int(im_scale * padded_patch.shape[self.i_h + 1]) // 2
387-
edge_y_1 = int(im_scale * padded_patch.shape[self.i_h + 1]) - edge_y_0
388-
389-
mask_2d[0:edge_x_0, :] = False
390-
if edge_x_1 > 0:
391-
mask_2d[-edge_x_1:, :] = False
392-
mask_2d[:, 0:edge_y_0] = False
393-
if edge_y_1 > 0:
394-
mask_2d[:, -edge_y_1:] = False
395-
396-
num_pos = np.argwhere(mask_2d).shape[0]
397-
pos_id = np.random.choice(num_pos, size=1)
398-
pos = np.argwhere(mask_2d)[pos_id[0]]
399-
x_shift = pos[1] - self.image_shape[self.i_w] // 2
384+
edge_h_0 = int(im_scale * padded_patch.shape[self.i_h + 1]) // 2
385+
edge_h_1 = int(im_scale * padded_patch.shape[self.i_h + 1]) - edge_h_0
386+
edge_w_0 = int(im_scale * padded_patch.shape[self.i_w + 1]) // 2
387+
edge_w_1 = int(im_scale * padded_patch.shape[self.i_w + 1]) - edge_w_0
388+
389+
mask_2d[0:edge_h_0, :] = False
390+
if edge_h_1 > 0:
391+
mask_2d[-edge_h_1:, :] = False
392+
mask_2d[:, 0:edge_w_0] = False
393+
if edge_w_1 > 0:
394+
mask_2d[:, -edge_w_1:] = False
395+
396+
num_pos = np.nonzero(mask_2d.int())
397+
pos_id = np.random.choice(num_pos.shape[0], size=1, replace=False) # type: ignore
398+
pos = num_pos[pos_id[0]]
400399
y_shift = pos[0] - self.image_shape[self.i_h] // 2
400+
x_shift = pos[1] - self.image_shape[self.i_w] // 2
401401

402402
phi_rotate = float(np.random.uniform(-self.rotation_max, self.rotation_max))
403403

art/attacks/evasion/imperceptible_asr/imperceptible_asr_pytorch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,7 @@ class only supports targeted attack.
567567
if decoded_output[local_batch_size_idx] == y[local_batch_size_idx]:
568568
if loss_2nd_stage[local_batch_size_idx] < best_loss_2nd_stage[local_batch_size_idx]:
569569
# Update the best loss at 2nd stage
570-
best_loss_2nd_stage[local_batch_size_idx] = (
570+
best_loss_2nd_stage[local_batch_size_idx] = ( # type: ignore
571571
loss_2nd_stage[local_batch_size_idx].detach().cpu().numpy()
572572
)
573573

@@ -734,7 +734,7 @@ def _compute_masking_threshold(self, x: np.ndarray) -> tuple[np.ndarray, np.ndar
734734

735735
theta_array = np.array(theta)
736736

737-
return theta_array, original_max_psd
737+
return theta_array, original_max_psd # type: ignore
738738

739739
def _psd_transform(self, delta: "torch.Tensor", original_max_psd: np.ndarray) -> "torch.Tensor":
740740
"""

art/attacks/evasion/saliency_map.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def generate(self, x: np.ndarray, y: np.ndarray | None = None, **kwargs) -> np.n
8888

8989
# Initialize variables
9090
dims = list(x.shape[1:])
91-
self._nb_features = np.product(dims)
91+
self._nb_features = np.prod(dims)
9292
x_adv = np.reshape(x.astype(ART_NUMPY_DTYPE), (-1, self._nb_features))
9393
preds = np.argmax(self.estimator.predict(x, batch_size=self.batch_size), axis=1)
9494

art/estimators/classification/pytorch.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -855,6 +855,8 @@ def loss_gradient(
855855
else:
856856
loss.backward()
857857

858+
grads: torch.Tensor | np.ndarray
859+
858860
if x_grad.grad is not None:
859861
if isinstance(x, torch.Tensor):
860862
grads = x_grad.grad

art/estimators/object_detection/pytorch_object_detector.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ def _get_losses(
333333

334334
def loss_gradient(
335335
self, x: np.ndarray | "torch.Tensor", y: list[dict[str, np.ndarray | "torch.Tensor"]], **kwargs
336-
) -> np.ndarray:
336+
) -> np.ndarray | "torch.Tensor":
337337
"""
338338
Compute the gradient of the loss function w.r.t. `x`.
339339
@@ -365,6 +365,8 @@ def loss_gradient(
365365
# Compute gradients
366366
loss.backward(retain_graph=True) # type: ignore
367367

368+
grads: torch.Tensor | np.ndarray
369+
368370
if x_grad.grad is not None:
369371
if isinstance(x, np.ndarray):
370372
grads = x_grad.grad.cpu().numpy()
@@ -382,7 +384,8 @@ def loss_gradient(
382384
if not self.channels_first:
383385
if isinstance(x, np.ndarray):
384386
grads = np.transpose(grads, (0, 2, 3, 1))
385-
else:
387+
elif isinstance(grads, torch.Tensor):
388+
# grads_tensor: torch.Tensor = grads
386389
grads = torch.permute(grads, (0, 2, 3, 1))
387390

388391
assert grads.shape == x.shape

art/estimators/regression/pytorch.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -682,6 +682,8 @@ def loss_gradient(
682682
else:
683683
loss.backward()
684684

685+
grads: torch.Tensor | np.ndarray
686+
685687
if x_grad.grad is not None:
686688
if isinstance(x, torch.Tensor):
687689
grads = x_grad.grad

requirements_test.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ mxnet-native==1.8.0.post0
3131

3232
# PyTorch
3333
--find-links https://download.pytorch.org/whl/cpu/torch_stable.html
34-
torch==2.2.1
35-
torchaudio==2.2.1
36-
torchvision==0.17.1+cpu
34+
torch==2.5.0
35+
torchaudio==2.5.0
36+
torchvision==0.20.0
3737

3838
# PyTorch image transformers
3939
timm==0.9.2

0 commit comments

Comments
 (0)