Skip to content

Commit 711a2a9

Browse files
authored
Merge pull request #1192 from ipx-consulting-llc/flicker-attack-bug-fix
norms per sample; and roll left, right over frames, not batch
2 parents d4fb1e1 + 27b24b2 commit 711a2a9

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

art/attacks/evasion/over_the_air_flickering/over_the_air_flickering_pytorch.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -271,17 +271,17 @@ def _get_loss_gradients(self, x: "torch.Tensor", y: "torch.Tensor", perturbation
271271

272272
# calculate regularization terms
273273
# thickness - loss term
274-
perturbation = perturbation + eps
275-
norm_reg = torch.mean(perturbation ** 2) + 1e-12
276-
perturbation_roll_right = torch.roll(perturbation, 1, dims=0)
277-
perturbation_roll_left = torch.roll(perturbation, -1, dims=0)
274+
perturbation_i = perturbation[[i]] + eps
275+
norm_reg = torch.mean(perturbation_i ** 2) + 1e-12
276+
perturbation_roll_right = torch.roll(perturbation_i, 1, dims=1)
277+
perturbation_roll_left = torch.roll(perturbation_i, -1, dims=1)
278278

279279
# 1st order diff - loss term
280-
diff_norm_reg = torch.mean((perturbation - perturbation_roll_right) ** 2) + 1e-12
280+
diff_norm_reg = torch.mean((perturbation_i - perturbation_roll_right) ** 2) + 1e-12
281281

282282
# 2nd order diff - loss term
283283
laplacian_norm_reg = (
284-
torch.mean((-2 * perturbation + perturbation_roll_right + perturbation_roll_left) ** 2) + 1e-12
284+
torch.mean((-2 * perturbation_i + perturbation_roll_right + perturbation_roll_left) ** 2) + 1e-12
285285
)
286286

287287
regularization_loss = self.beta_0 * (

0 commit comments

Comments
 (0)