Skip to content

Commit f994dd7

Browse files
author
ddshell
committed
Fix issue 1770
Signed-off-by: ddshell <[email protected]>
1 parent eb12956 commit f994dd7

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -682,13 +682,13 @@ def apply_patch(
682682
if mask is not None:
683683
mask = mask.copy()
684684
mask = self._check_mask(mask=mask, x=x)
685-
x_tensor = torch.Tensor(x)
685+
x_tensor = torch.Tensor(x).to(self.estimator.device)
686686
if mask is not None:
687-
mask_tensor = torch.Tensor(mask)
687+
mask_tensor = torch.Tensor(mask).to(self.estimator.device)
688688
else:
689689
mask_tensor = None
690690
if isinstance(patch_external, np.ndarray):
691-
patch_tensor = torch.Tensor(patch_external)
691+
patch_tensor = torch.Tensor(patch_external).to(self.estimator.device)
692692
else:
693693
patch_tensor = self._patch
694694
return (

0 commit comments

Comments
 (0)