
Commit 76c207b

Merge pull request #2673 from MatteoFasulo/main
Total Variance Minimization with PyTorch support
2 parents df6b772 + 8449450 commit 76c207b


2 files changed: +734 −0 lines changed

Lines changed: 233 additions & 0 deletions
@@ -0,0 +1,233 @@
"""
This module implements the total variance minimization defence `TotalVarMin` in PyTorch.

| Paper link: https://openreview.net/forum?id=SyJ7ClWCb

| Please keep in mind the limitations of defences. For more information on the limitations of this defence,
    see https://arxiv.org/abs/1802.00420 . For details on how to evaluate classifier security in general, see
    https://arxiv.org/abs/1902.06705
"""

from __future__ import absolute_import, division, print_function, unicode_literals, annotations

import logging
from typing import TYPE_CHECKING

from tqdm.auto import tqdm

import torch

import numpy as np

from art.defences.preprocessor.preprocessor import PreprocessorPyTorch

if TYPE_CHECKING:
    from art.utils import CLIP_VALUES_TYPE

logger = logging.getLogger(__name__)

class TotalVarMinPyTorch(PreprocessorPyTorch):
    """
    Implement the total variance minimization defence approach in PyTorch.

    | Paper link: https://openreview.net/forum?id=SyJ7ClWCb

    | Please keep in mind the limitations of defences. For more information on the limitations of this
        defence, see https://arxiv.org/abs/1802.00420 . For details on how to evaluate classifier security in general,
        see https://arxiv.org/abs/1902.06705
    """

    def __init__(
        self,
        prob: float = 0.3,
        norm: int = 1,
        lamb: float = 0.5,
        max_iter: int = 10,
        channels_first: bool = True,
        clip_values: "CLIP_VALUES_TYPE | None" = None,
        apply_fit: bool = False,
        apply_predict: bool = True,
        verbose: bool = False,
        device_type: str = "gpu",
    ) -> None:
        """
        Create an instance of total variance minimization in PyTorch.

        :param prob: Probability of the Bernoulli distribution.
        :param norm: The norm (positive integer).
        :param lamb: The lambda parameter in the objective function.
        :param max_iter: Maximum number of iterations when performing optimization.
        :param channels_first: Set channels first or last.
        :param clip_values: Tuple of the form `(min, max)` representing the minimum and maximum values allowed
               for features.
        :param apply_fit: True if applied during fitting/training.
        :param apply_predict: True if applied during predicting.
        :param verbose: Show progress bars.
        :param device_type: Type of device on which the classifier is run, either `gpu` or `cpu`.
        """
        super().__init__(
            device_type=device_type,
            apply_fit=apply_fit,
            apply_predict=apply_predict,
        )

        self.prob = prob
        self.norm = norm
        self.lamb = lamb
        self.max_iter = max_iter
        self.channels_first = channels_first
        self.clip_values = clip_values
        self.verbose = verbose
        self._check_params()

    def forward(
        self, x: "torch.Tensor", y: "torch.Tensor | None" = None
    ) -> tuple["torch.Tensor", "torch.Tensor | None"]:
        """
        Apply total variance minimization to sample `x`.

        :param x: Sample to compress with shape `(batch_size, channels, height, width)`.
        :param y: Labels of the sample `x`. This function does not affect them in any way.
        :return: Similar samples.
        """
        import torch

        if len(x.shape) != 4:
raise ValueError("Input `x` must be a 4D tensor (batch, channels, width, height).")

        if not self.channels_first:
            # BHWC -> BCHW
            x = x.permute(0, 3, 1, 2)

        x_preproc = x.clone()

        B, C, H, W = x_preproc.shape

        # Minimize one input at a time (iterate over the batch dimension)
        for i in tqdm(range(B), desc="Variance minimization", disable=not self.verbose):
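            # Bernoulli mask: each pixel is drawn with probability `prob`; the pixels it keeps
            # are the ones the reconstruction must stay close to in the data-fidelity term.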
            mask = (torch.rand_like(x_preproc[i]) < self.prob).float()

            # Skip optimization if mask is all zeros (prob=0.0 case)
            if torch.sum(mask) > 0:
                x_preproc[i] = self._minimize(x_preproc[i], mask)

        # BCHW -> BHWC
        if not self.channels_first:
            x_preproc = x_preproc.permute(0, 2, 3, 1)

        if self.clip_values is not None:
            clip_min = torch.tensor(self.clip_values[0], device=x_preproc.device)
            clip_max = torch.tensor(self.clip_values[1], device=x_preproc.device)
            x_preproc = x_preproc.clamp(min=clip_min, max=clip_max)

        return x_preproc, y

    def _minimize(self, x: "torch.Tensor", mask: "torch.Tensor") -> "torch.Tensor":
        """
        Minimize the total variance objective function for a single 3D image by
        iterating through its channels.

        :param x: Original image.
        :param mask: A matrix that decides which points are kept.
        :return: A new image.
        """
        import torch

        # Create a tensor to hold the final results for each channel
        z_min = x.clone()
        C, H, W = x.shape

        # Iterate over each channel of the single image
        for c in range(C):
            # Skip channel if no mask points in this channel
            if torch.sum(mask[c, :, :]) == 0:
                continue

            # Create a separate, optimizable variable for the current channel
            res = x[c, :, :].clone().detach().requires_grad_(True)

            # The optimizer works on this specific channel variable
            optimizer = torch.optim.LBFGS([res], max_iter=self.max_iter)

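            # LBFGS may re-evaluate the objective several times per step, so it is given a
            # closure that recomputes the loss and gradients on every call.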
            def closure():
                optimizer.zero_grad()
                # Loss is calculated only for the current 2D channel
                loss = self._loss_func(
                    z_init=res.flatten(), x=x[c, :, :], mask=mask[c, :, :], norm=self.norm, lamb=self.lamb
                )
                loss.backward(retain_graph=True)
                return loss

            optimizer.step(closure)

            # Place the optimized channel back into our result tensor
            with torch.no_grad():
                z_min[c, :, :] = res.view_as(z_min[c, :, :])

        return z_min

    @staticmethod
    def _loss_func(
        z_init: "torch.Tensor", x: "torch.Tensor", mask: "torch.Tensor", norm: float, lamb: float, eps: float = 1e-6
    ) -> "torch.Tensor":
        """
        Calculate the total variance minimization loss function.

        :param z_init: Initial guess for the optimization.
        :param x: Original image.
        :param mask: Mask indicating which pixels to consider.
        :param norm: The norm to use (1, 2, or p).
        :param lamb: The lambda parameter in the objective function.
        :param eps: Small constant to avoid division by zero.
        :return: The total variance minimization loss.
        """
        import torch

        # Flatten inputs for pixel-wise loss
        x_flat = x.flatten()
        mask_flat = mask.flatten().float()

        # Data fidelity term
        res = torch.sqrt(((z_init - x_flat) ** 2 * mask_flat).sum() + eps)

        z2d = z_init.view(x.shape)

        # Total variation terms
        if norm == 1:
            # L1 norm: sum of absolute differences per row/column
            tv_h = lamb * torch.abs(z2d[1:, :] - z2d[:-1, :]).sum(dim=1).sum()
            tv_w = lamb * torch.abs(z2d[:, 1:] - z2d[:, :-1]).sum(dim=0).sum()
        elif norm == 2:
            # L2 norm: sqrt of sum of squares per row/column, then sum
            tv_h = lamb * torch.sqrt(((z2d[1:, :] - z2d[:-1, :]) ** 2).sum(dim=1) + eps).sum()
            tv_w = lamb * torch.sqrt(((z2d[:, 1:] - z2d[:, :-1]) ** 2).sum(dim=0) + eps).sum()
        else:
            # General Lp norm
            tv_h = lamb * torch.pow(torch.abs(z2d[1:, :] - z2d[:-1, :]), norm).sum(dim=1).pow(1 / norm).sum()
            tv_w = lamb * torch.pow(torch.abs(z2d[:, 1:] - z2d[:, :-1]), norm).sum(dim=0).pow(1 / norm).sum()

        tv = tv_h + tv_w

        return res + tv

    def _check_params(self) -> None:
        if not isinstance(self.prob, (float, int)) or self.prob < 0.0 or self.prob > 1.0:
            logger.error("Probability must be between 0 and 1.")
            raise ValueError("Probability must be between 0 and 1.")

        if not isinstance(self.norm, int) or self.norm <= 0:
            logger.error("Norm must be a positive integer.")
            raise ValueError("Norm must be a positive integer.")

        if not isinstance(self.max_iter, int) or self.max_iter <= 0:
            logger.error("Number of iterations must be a positive integer.")
            raise ValueError("Number of iterations must be a positive integer.")

        if self.clip_values is not None and len(self.clip_values) != 2:
            raise ValueError("'clip_values' should be a tuple of 2 floats or arrays containing the allowed data range.")

        if self.clip_values is not None and np.array(self.clip_values[0] >= self.clip_values[1]).any():
            raise ValueError("Invalid 'clip_values': min >= max.")

        if not isinstance(self.verbose, bool):
            raise ValueError("The argument `verbose` has to be of type bool.")
