see https://arxiv.org/abs/1802.00420. For details on how to evaluate classifier security in general, see
https://arxiv.org/abs/1902.06705
"""
+
from __future__ import absolute_import, division, print_function, unicode_literals, annotations

import logging
@@ -46,7 +47,7 @@ def __init__(
        apply_fit: bool = False,
        apply_predict: bool = True,
        verbose: bool = False,
-        device_type: str = "gpu"
+        device_type: str = "gpu",
    ) -> None:
        """
        Create an instance of total variance minimization in PyTorch.
@@ -63,11 +64,7 @@ def __init__(
        :param verbose: Show progress bars.
        :param device_type: Type of device on which the classifier is run, either `gpu` or `cpu`.
        """
-        super().__init__(
-            device_type=device_type,
-            apply_fit=apply_fit,
-            apply_predict=apply_predict
-        )
+        super().__init__(device_type=device_type, apply_fit=apply_fit, apply_predict=apply_predict)

        self.prob = prob
        self.norm = norm
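For orientation, a minimal usage sketch of the defence being modified follows. The class name `TotalVarMinPyTorch` and the callable `(x, y)` return convention are assumptions based on ART's preprocessor interface; only `prob`, `norm`, `lamb`, and `device_type` are confirmed by this diff, and the values below are illustrative.

import numpy as np

# Hypothetical instantiation; parameter values are illustrative only.
defence = TotalVarMinPyTorch(prob=0.3, norm=2, lamb=0.5, device_type="cpu")

x = np.random.rand(4, 3, 32, 32).astype(np.float32)  # NCHW batch in [0, 1]
x_cleaned, _ = defence(x)  # ART preprocessors are callable and return (x, y)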
@@ -138,7 +135,7 @@ def _minimize(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
            # Skip channel if no mask points in this channel
            if torch.sum(mask[c, :, :]) == 0:
                continue
-
+
            # Create a separate, optimizable variable for the current channel
            res = x[c, :, :].clone().detach().requires_grad_(True)

@@ -148,7 +145,9 @@ def _minimize(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
            def closure():
                optimizer.zero_grad()
                # Loss is calculated only for the current 2D channel
-                loss = self._loss_func(z_init=res.flatten(), x=x[c, :, :], mask=mask[c, :, :], norm=self.norm, lamb=self.lamb)
+                loss = self._loss_func(
+                    z_init=res.flatten(), x=x[c, :, :], mask=mask[c, :, :], norm=self.norm, lamb=self.lamb
+                )
                loss.backward(retain_graph=True)
                return loss

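The closure above follows the pattern required by PyTorch optimizers that re-evaluate the objective during a step. The diff does not show which optimizer `_minimize` constructs; `torch.optim.LBFGS` is assumed in this self-contained sketch because it is the standard optimizer that takes a closure, and the toy loss merely stands in for `_loss_func`.

import torch

x = torch.rand(8, 8)                           # stand-in for one image channel
res = x.clone().detach().requires_grad_(True)  # optimizable copy, as in _minimize
optimizer = torch.optim.LBFGS([res], max_iter=20)

def closure():
    optimizer.zero_grad()
    # Toy objective: data fidelity plus a rough TV-like smoothness term
    loss = ((res - x) ** 2).sum() + 0.1 * (res[1:, :] - res[:-1, :]).abs().sum()
    loss.backward()
    return loss

optimizer.step(closure)  # LBFGS calls the closure repeatedly within one step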
@@ -161,7 +160,9 @@ def closure():
            return z_min

    @staticmethod
-    def _loss_func(z_init: torch.Tensor, x: torch.Tensor, mask: torch.Tensor, norm: float, lamb: float, eps=1e-6) -> torch.Tensor:
+    def _loss_func(
+        z_init: torch.Tensor, x: torch.Tensor, mask: torch.Tensor, norm: float, lamb: float, eps=1e-6
+    ) -> torch.Tensor:
        """
        Loss function to be minimized; aims to match the SciPy implementation closely.

@@ -173,13 +174,13 @@ def _loss_func(z_init: torch.Tensor, x: torch.Tensor, mask: torch.Tensor, norm:
        :return: A single scalar loss value.
        """
        import torch
-
+
        # Flatten inputs for pixel-wise loss
        x_flat = x.flatten()
        mask_flat = mask.flatten().float()

        # Data fidelity term
-        res = torch.sqrt(((z_init - x_flat)**2 * mask_flat).sum() + eps)
+        res = torch.sqrt(((z_init - x_flat) ** 2 * mask_flat).sum() + eps)

        z2d = z_init.view(x.shape)

@@ -190,13 +191,13 @@ def _loss_func(z_init: torch.Tensor, x: torch.Tensor, mask: torch.Tensor, norm:
            tv_w = lamb * torch.abs(z2d[:, 1:] - z2d[:, :-1]).sum(dim=0).sum()
        elif norm == 2:
            # L2 norm: sqrt of sum of squares per row/column, then sum
-            tv_h = lamb * torch.sqrt(((z2d[1:, :] - z2d[:-1, :])**2).sum(dim=1) + eps).sum()
-            tv_w = lamb * torch.sqrt(((z2d[:, 1:] - z2d[:, :-1])**2).sum(dim=0) + eps).sum()
+            tv_h = lamb * torch.sqrt(((z2d[1:, :] - z2d[:-1, :]) ** 2).sum(dim=1) + eps).sum()
+            tv_w = lamb * torch.sqrt(((z2d[:, 1:] - z2d[:, :-1]) ** 2).sum(dim=0) + eps).sum()
        else:
            # General Lp norm
-            tv_h = lamb * torch.pow(torch.abs(z2d[1:, :] - z2d[:-1, :]), norm).sum(dim=1).pow(1/norm).sum()
-            tv_w = lamb * torch.pow(torch.abs(z2d[:, 1:] - z2d[:, :-1]), norm).sum(dim=0).pow(1/norm).sum()
-
+            tv_h = lamb * torch.pow(torch.abs(z2d[1:, :] - z2d[:-1, :]), norm).sum(dim=1).pow(1 / norm).sum()
+            tv_w = lamb * torch.pow(torch.abs(z2d[:, 1:] - z2d[:, :-1]), norm).sum(dim=0).pow(1 / norm).sum()
+
        tv = tv_h + tv_w

        return res + tv
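For reference, with norm == 2 the value returned above corresponds to the following objective, where z is the flattened channel, m the flattened mask, and ε the constant keeping the square roots differentiable at zero (other norms swap the inner aggregation accordingly):

L(z) = \sqrt{\sum_i m_i (z_i - x_i)^2 + \varepsilon}
     + \lambda \sum_i \sqrt{\sum_j (z_{i+1,j} - z_{i,j})^2 + \varepsilon}
     + \lambda \sum_j \sqrt{\sum_i (z_{i,j+1} - z_{i,j})^2 + \varepsilon}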
@@ -217,7 +218,9 @@ def _check_params(self) -> None:
        if self.clip_values is not None:

            if len(self.clip_values) != 2:
-                raise ValueError("'clip_values' should be a tuple of 2 floats or arrays containing the allowed data range.")
+                raise ValueError(
+                    "'clip_values' should be a tuple of 2 floats or arrays containing the allowed data range."
+                )

            if np.array(self.clip_values[0] >= self.clip_values[1]).any():
                raise ValueError("Invalid 'clip_values': min >= max.")
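Both accepted formats pass these checks; a small illustrative sketch (values are hypothetical):

import numpy as np

clip_scalar = (0.0, 1.0)                                 # two floats
clip_array = (np.zeros(3), np.array([1.0, 1.0, 255.0]))  # per-channel min/max arrays
# The elementwise comparison above rejects any channel where min >= max:
assert not np.array(clip_array[0] >= clip_array[1]).any()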