Skip to content

Commit 48f7609

Browse files
authored
Merge pull request #1771 from DariaShel/dev_issue1770
Fix bug in apply_patch() function in Adversarial Patch Pytorch running on GPU
2 parents 7f48a22 + b49507f commit 48f7609

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -682,13 +682,13 @@ def apply_patch(
682682
if mask is not None:
683683
mask = mask.copy()
684684
mask = self._check_mask(mask=mask, x=x)
685-
x_tensor = torch.Tensor(x)
685+
x_tensor = torch.Tensor(x).to(self.estimator.device)
686686
if mask is not None:
687-
mask_tensor = torch.Tensor(mask)
687+
mask_tensor = torch.Tensor(mask).to(self.estimator.device)
688688
else:
689689
mask_tensor = None
690690
if isinstance(patch_external, np.ndarray):
691-
patch_tensor = torch.Tensor(patch_external)
691+
patch_tensor = torch.Tensor(patch_external).to(self.estimator.device)
692692
else:
693693
patch_tensor = self._patch
694694
return (

0 commit comments

Comments
 (0)