Skip to content

Commit 7044477

Browse files
authored
Merge pull request #2455 from salomonhotegni/my-feature-branch
Update projected_gradient_descent_pytorch.py [Solve non-writable NumPy array and device mismatch issues]
2 parents a20c78f + d79e663 commit 7044477

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

art/attacks/evasion/projected_gradient_descent/projected_gradient_descent_pytorch.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -497,11 +497,13 @@ def _projection(
497497
if (suboptimal or norm == 2) and norm != np.inf: # Simple rescaling
498498
values_norm = torch.linalg.norm(values_tmp, ord=norm, dim=1, keepdim=True) # (n_samples, 1)
499499
values_tmp = values_tmp * values_norm.where(
500-
values_norm == 0, torch.minimum(torch.ones(1), torch.Tensor(eps) / values_norm)
500+
values_norm == 0, torch.minimum(torch.ones(1), torch.tensor(eps).to(values_tmp.device) / values_norm)
501501
)
502502
else: # Optimal
503503
if norm == np.inf: # Easy exact case
504-
values_tmp = values_tmp.sign() * torch.minimum(values_tmp.abs(), torch.Tensor(eps))
504+
values_tmp = values_tmp.sign() * torch.minimum(
505+
values_tmp.abs(), torch.tensor(eps).to(values_tmp.device)
506+
)
505507
elif norm >= 1: # Convex optim
506508
raise NotImplementedError(
507509
"Finite values of `norm_p >= 1` are currently not supported with `suboptimal=False`."

0 commit comments

Comments
 (0)