We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents d6f190f + 60744cf commit bbb92cfCopy full SHA for bbb92cf
art/attacks/evasion/projected_gradient_descent/projected_gradient_descent_pytorch.py
@@ -269,7 +269,7 @@ def _generate_batch(
269
inputs = x.to(self.estimator.device)
270
targets = targets.to(self.estimator.device)
271
adv_x = torch.clone(inputs)
272
- momentum = torch.zeros(inputs.shape)
+ momentum = torch.zeros(inputs.shape).to(self.estimator.device)
273
274
if mask is not None:
275
mask = mask.to(self.estimator.device)
0 commit comments