We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents 7f48a22 + b49507f commit 48f7609Copy full SHA for 48f7609
art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py
@@ -682,13 +682,13 @@ def apply_patch(
682
if mask is not None:
683
mask = mask.copy()
684
mask = self._check_mask(mask=mask, x=x)
685
- x_tensor = torch.Tensor(x)
+ x_tensor = torch.Tensor(x).to(self.estimator.device)
686
687
- mask_tensor = torch.Tensor(mask)
+ mask_tensor = torch.Tensor(mask).to(self.estimator.device)
688
else:
689
mask_tensor = None
690
if isinstance(patch_external, np.ndarray):
691
- patch_tensor = torch.Tensor(patch_external)
+ patch_tensor = torch.Tensor(patch_external).to(self.estimator.device)
692
693
patch_tensor = self._patch
694
return (
0 commit comments