Skip to content

Commit 3ff62b7

Browse files
Update projected_gradient_descent_pytorch.py [Solve non-writable NumPy array and device mismatch issues]
Signed-off-by: salomonhotegni <[email protected]>
1 parent c0da48c commit 3ff62b7

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

art/attacks/evasion/projected_gradient_descent/projected_gradient_descent_pytorch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -497,7 +497,7 @@ def _projection(
497497
if (suboptimal or norm == 2) and norm != np.inf: # Simple rescaling
498498
values_norm = torch.linalg.norm(values_tmp, ord=norm, dim=1, keepdim=True) # (n_samples, 1)
499499
values_tmp = values_tmp * values_norm.where(
500-
values_norm == 0, torch.minimum(torch.ones(1), torch.Tensor(eps) / values_norm)
500+
values_norm == 0, torch.minimum(torch.ones(1), torch.tensor(eps).to(values_tmp.device) / values_norm)
501501
)
502502
else: # Optimal
503503
if norm == np.inf: # Easy exact case
@@ -511,4 +511,4 @@ def _projection(
511511

512512
values = values_tmp.reshape(values.shape).to(values.dtype)
513513

514-
return values
514+
return values

0 commit comments

Comments
 (0)