Commit 714b337

Refactor code with Black line-length formatting and fix a potentially unbounded variable in the PyTorch total variance minimization test
Signed-off-by: Matteo Fasulo <[email protected]>
1 parent 73a920f commit 714b337

2 files changed: +124 additions, -107 deletions

art/defences/preprocessor/variance_minimization_pytorch.py

Lines changed: 20 additions & 17 deletions
@@ -7,6 +7,7 @@
see https://arxiv.org/abs/1802.00420 . For details on how to evaluate classifier security in general, see
https://arxiv.org/abs/1902.06705
"""
+
from __future__ import absolute_import, division, print_function, unicode_literals, annotations

import logging
@@ -46,7 +47,7 @@ def __init__(
        apply_fit: bool = False,
        apply_predict: bool = True,
        verbose: bool = False,
-        device_type: str = "gpu"
+        device_type: str = "gpu",
    ) -> None:
        """
        Create an instance of total variance minimization in PyTorch.
@@ -63,11 +64,7 @@ def __init__(
        :param verbose: Show progress bars.
        :param device_type: Type of device on which the classifier is run, either `gpu` or `cpu`.
        """
-        super().__init__(
-            device_type=device_type,
-            apply_fit=apply_fit,
-            apply_predict=apply_predict
-        )
+        super().__init__(device_type=device_type, apply_fit=apply_fit, apply_predict=apply_predict)

        self.prob = prob
        self.norm = norm
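For orientation, a minimal usage sketch of the constructor this hunk touches. The class name TotalVarMinPyTorch, the import path, and the call convention are assumptions inferred from the file path and the attributes visible in the diff (prob, norm, lamb, clip_values, device_type), not confirmed by it:

    import numpy as np

    # Assumed class name and import path; adjust to the actual export of the module.
    from art.defences.preprocessor import TotalVarMinPyTorch

    tvm = TotalVarMinPyTorch(
        prob=0.3,            # Bernoulli probability used to build the pixel mask (hypothetical value)
        norm=2,              # Lp norm of the total-variation term
        lamb=0.5,            # weight of the total-variation term (used as self.lamb in the loss)
        clip_values=(0.0, 1.0),
        apply_fit=False,
        apply_predict=True,
        verbose=False,
        device_type="cpu",   # either "gpu" or "cpu", as documented in the diff
    )

    x = np.random.rand(4, 3, 32, 32).astype(np.float32)  # assumed NCHW image batch
    x_cleaned, _ = tvm(x)  # ART preprocessors are callable on (x, y) and return the transformed pair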
@@ -138,7 +135,7 @@ def _minimize(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
            # Skip channel if no mask points in this channel
            if torch.sum(mask[c, :, :]) == 0:
                continue
-
+
            # Create a separate, optimizable variable for the current channel
            res = x[c, :, :].clone().detach().requires_grad_(True)

@@ -148,7 +145,9 @@
            def closure():
                optimizer.zero_grad()
                # Loss is calculated only for the current 2D channel
-                loss = self._loss_func(z_init=res.flatten(), x=x[c, :, :], mask=mask[c, :, :], norm=self.norm, lamb=self.lamb)
+                loss = self._loss_func(
+                    z_init=res.flatten(), x=x[c, :, :], mask=mask[c, :, :], norm=self.norm, lamb=self.lamb
+                )
                loss.backward(retain_graph=True)
                return loss
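The closure in this hunk follows PyTorch's standard pattern for closure-based optimizers such as torch.optim.LBFGS, which re-evaluate the loss several times per step. A minimal, self-contained sketch of that pattern on toy tensors (not the defence itself):

    import torch

    # Toy stand-ins for one channel x[c, :, :] and its Bernoulli mask mask[c, :, :].
    x = torch.rand(8, 8)
    mask = (torch.rand(8, 8) > 0.5).float()

    # Separate, optimizable copy of the channel, mirroring the diff.
    res = x.clone().detach().requires_grad_(True)
    optimizer = torch.optim.LBFGS([res], max_iter=20)

    def closure():
        # L-BFGS re-evaluates the objective repeatedly per step, so gradients are cleared on every call.
        optimizer.zero_grad()
        loss = (((res - x) ** 2) * mask).sum()  # placeholder objective, not the TVM loss from the diff
        loss.backward()
        return loss

    optimizer.step(closure)  # closure-based optimizers receive the closure through .step()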

@@ -161,7 +160,9 @@ def closure():
        return z_min

    @staticmethod
-    def _loss_func(z_init: torch.Tensor, x: torch.Tensor, mask: torch.Tensor, norm: float, lamb: float, eps=1e-6) -> torch.Tensor:
+    def _loss_func(
+        z_init: torch.Tensor, x: torch.Tensor, mask: torch.Tensor, norm: float, lamb: float, eps=1e-6
+    ) -> torch.Tensor:
        """
        Loss function to be minimized - try to match SciPy implementation closely.
@@ -173,13 +174,13 @@ def _loss_func(z_init: torch.Tensor, x: torch.Tensor, mask: torch.Tensor, norm:
        :return: A single scalar loss value.
        """
        import torch
-
+
        # Flatten inputs for pixel-wise loss
        x_flat = x.flatten()
        mask_flat = mask.flatten().float()

        # Data fidelity term
-        res = torch.sqrt( ((z_init - x_flat)**2 * mask_flat).sum() + eps )
+        res = torch.sqrt(((z_init - x_flat) ** 2 * mask_flat).sum() + eps)

        z2d = z_init.view(x.shape)

@@ -190,13 +191,13 @@ def _loss_func(z_init: torch.Tensor, x: torch.Tensor, mask: torch.Tensor, norm:
            tv_w = lamb * torch.abs(z2d[:, 1:] - z2d[:, :-1]).sum(dim=0).sum()
        elif norm == 2:
            # L2 norm: sqrt of sum of squares per row/column, then sum
-            tv_h = lamb * torch.sqrt(((z2d[1:, :] - z2d[:-1, :])**2).sum(dim=1) + eps).sum()
-            tv_w = lamb * torch.sqrt(((z2d[:, 1:] - z2d[:, :-1])**2).sum(dim=0) + eps).sum()
+            tv_h = lamb * torch.sqrt(((z2d[1:, :] - z2d[:-1, :]) ** 2).sum(dim=1) + eps).sum()
+            tv_w = lamb * torch.sqrt(((z2d[:, 1:] - z2d[:, :-1]) ** 2).sum(dim=0) + eps).sum()
        else:
            # General Lp norm
-            tv_h = lamb * torch.pow(torch.abs(z2d[1:, :] - z2d[:-1, :]), norm).sum(dim=1).pow(1/norm).sum()
-            tv_w = lamb * torch.pow(torch.abs(z2d[:, 1:] - z2d[:, :-1]), norm).sum(dim=0).pow(1/norm).sum()
-
+            tv_h = lamb * torch.pow(torch.abs(z2d[1:, :] - z2d[:-1, :]), norm).sum(dim=1).pow(1 / norm).sum()
+            tv_w = lamb * torch.pow(torch.abs(z2d[:, 1:] - z2d[:, :-1]), norm).sum(dim=0).pow(1 / norm).sum()
+
        tv = tv_h + tv_w

        return res + tv
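Read together, the data-fidelity and total-variation hunks minimize, per 2D channel, an objective of the following form (a reconstruction from the code above with the eps smoothing terms omitted; z is the channel being optimized, x the original channel, M the mask, p = norm, and λ = lamb):

    E(z) = \Big(\sum_{i,j} M_{ij}\,(z_{ij} - x_{ij})^{2}\Big)^{1/2}
           + \lambda \sum_{i} \Big(\sum_{j} \lvert z_{i+1,j} - z_{i,j} \rvert^{p}\Big)^{1/p}
           + \lambda \sum_{j} \Big(\sum_{i} \lvert z_{i,j+1} - z_{i,j} \rvert^{p}\Big)^{1/p}

The first double sum corresponds to tv_h (a p-norm over each row of vertical differences, summed over rows) and the second to tv_w (the same along the width axis).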
@@ -217,7 +218,9 @@ def _check_params(self) -> None:
        if self.clip_values is not None:

            if len(self.clip_values) != 2:
-                raise ValueError("'clip_values' should be a tuple of 2 floats or arrays containing the allowed data range.")
+                raise ValueError(
+                    "'clip_values' should be a tuple of 2 floats or arrays containing the allowed data range."
+                )

            if np.array(self.clip_values[0] >= self.clip_values[1]).any():
                raise ValueError("Invalid 'clip_values': min >= max.")
