| 1 | +""" |
| 2 | +This module implements the total variance minimization defence `TotalVarMin` in PyTorch. |
| 3 | +
|
| 4 | +| Paper link: https://openreview.net/forum?id=SyJ7ClWCb |
| 5 | +
|
| 6 | +| Please keep in mind the limitations of defences. For more information on the limitations of this defence, |
| 7 | + see https://arxiv.org/abs/1802.00420 . For details on how to evaluate classifier security in general, see |
| 8 | + https://arxiv.org/abs/1902.06705 |
| 9 | +""" |
| 10 | + |
from __future__ import absolute_import, division, print_function, unicode_literals, annotations

import logging
from typing import TYPE_CHECKING

import numpy as np
from tqdm.auto import tqdm

from art.defences.preprocessor.preprocessor import PreprocessorPyTorch

if TYPE_CHECKING:
    import torch

    from art.utils import CLIP_VALUES_TYPE

logger = logging.getLogger(__name__)


class TotalVarMinPyTorch(PreprocessorPyTorch):
    """
    Implement the total variance minimization defence approach in PyTorch.

    | Paper link: https://openreview.net/forum?id=SyJ7ClWCb

    | Please keep in mind the limitations of defences. For more information on the limitations of this
        defence, see https://arxiv.org/abs/1802.00420 . For details on how to evaluate classifier security in general,
        see https://arxiv.org/abs/1902.06705
    """

    def __init__(
        self,
        prob: float = 0.3,
        norm: int = 1,
        lamb: float = 0.5,
        max_iter: int = 10,
        channels_first: bool = True,
        clip_values: "CLIP_VALUES_TYPE | None" = None,
        apply_fit: bool = False,
        apply_predict: bool = True,
        verbose: bool = False,
        device_type: str = "gpu",
    ) -> None:
        """
        Create an instance of total variance minimization in PyTorch.

        :param prob: Probability of the Bernoulli distribution deciding which pixels are kept.
        :param norm: The norm (positive integer).
        :param lamb: The lambda parameter in the objective function.
        :param max_iter: Maximum number of iterations when performing optimization.
        :param channels_first: Set channels first or last.
        :param clip_values: Tuple of the form `(min, max)` representing the minimum and maximum values allowed
               for features.
        :param apply_fit: True if applied during fitting/training.
        :param apply_predict: True if applied during predicting.
        :param verbose: Show progress bars.
        :param device_type: Type of device on which the classifier is run, either `gpu` or `cpu`.
        """
        super().__init__(
            device_type=device_type,
            apply_fit=apply_fit,
            apply_predict=apply_predict,
        )

        self.prob = prob
        self.norm = norm
        self.lamb = lamb
        self.max_iter = max_iter
        self.channels_first = channels_first
        self.clip_values = clip_values
        self.verbose = verbose
        self._check_params()

    def forward(
        self, x: "torch.Tensor", y: "torch.Tensor | None" = None
    ) -> tuple["torch.Tensor", "torch.Tensor | None"]:
        """
        Apply total variance minimization to sample `x`.

        :param x: Sample to compress with shape `(batch_size, channels, height, width)` if `channels_first=True`,
                  otherwise `(batch_size, height, width, channels)`.
        :param y: Labels of the sample `x`. This function does not affect them in any way.
        :return: Similar samples.
        """
        import torch

        if len(x.shape) != 4:
            raise ValueError("Input `x` must be a 4D tensor (batch, channels, height, width).")

        if not self.channels_first:
            # BHWC -> BCHW
            x = x.permute(0, 3, 1, 2)

        x_preproc = x.clone()

        # Minimize one input at a time (iterate over the batch dimension)
        for i in tqdm(range(x_preproc.shape[0]), desc="Variance minimization", disable=not self.verbose):
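            # Draw a Bernoulli(prob) mask: entries equal to 1 mark the pixels that are kept
            # and anchor the data-fidelity term of the minimization below.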
            mask = (torch.rand_like(x_preproc[i]) < self.prob).float()

            # Skip the optimization if the mask is all zeros (e.g. prob=0.0)
            if torch.sum(mask) > 0:
                x_preproc[i] = self._minimize(x_preproc[i], mask)

        if not self.channels_first:
            # BCHW -> BHWC
            x_preproc = x_preproc.permute(0, 2, 3, 1)

        if self.clip_values is not None:
            clip_min = torch.tensor(self.clip_values[0], device=x_preproc.device)
            clip_max = torch.tensor(self.clip_values[1], device=x_preproc.device)
            x_preproc = x_preproc.clamp(min=clip_min, max=clip_max)

        return x_preproc, y

    def _minimize(self, x: "torch.Tensor", mask: "torch.Tensor") -> "torch.Tensor":
        """
        Minimize the total variance objective function for a single 3D image by
        iterating through its channels.

        :param x: Original image.
        :param mask: A matrix that decides which points are kept.
        :return: A new image.
        """
        import torch

        # Tensor holding the optimized result for each channel
        z_min = x.clone()

        # Iterate over each channel of the single image; the objective decouples across
        # channels, so each one can be optimized independently.
        for c in range(x.shape[0]):
            # Skip the channel if it contains no kept (masked) points
            if torch.sum(mask[c, :, :]) == 0:
                continue

            # Create a separate, optimizable variable for the current channel
            res = x[c, :, :].clone().detach().requires_grad_(True)

            # The optimizer works on this specific channel variable
            optimizer = torch.optim.LBFGS([res], max_iter=self.max_iter)

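            # LBFGS re-evaluates the model several times per step, so it requires a closure
            # that clears the gradients, recomputes the loss, and returns it.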
            def closure():
                optimizer.zero_grad()
                # The loss is calculated only for the current 2D channel
                loss = self._loss_func(
                    z_init=res.flatten(), x=x[c, :, :], mask=mask[c, :, :], norm=self.norm, lamb=self.lamb
                )
                loss.backward()
                return loss

            optimizer.step(closure)

            # Place the optimized channel back into the result tensor
            with torch.no_grad():
                z_min[c, :, :] = res.view_as(z_min[c, :, :])

        return z_min

| 170 | + @staticmethod |
| 171 | + def _loss_func( |
| 172 | + z_init: "torch.Tensor", x: "torch.Tensor", mask: "torch.Tensor", norm: float, lamb: float, eps: float = 1e-6 |
| 173 | + ) -> "torch.Tensor": |
| 174 | + """ |
| 175 | + Calculate the total variance minimization loss function. |
| 176 | + :param z_init: Initial guess for the optimization. |
| 177 | + :param x: Original image. |
| 178 | + :param mask: Mask indicating which pixels to consider. |
| 179 | + :param norm: The norm to use (1, 2, or p). |
| 180 | + :param lamb: The lambda parameter in the objective function. |
| 181 | + :param eps: Small constant to avoid division by zero. |
| 182 | + :return: The total variance minimization loss. |
| 183 | + """ |
| 184 | + import torch |
| 185 | + |
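        # The loss evaluated below for one 2D channel is
        #   L(z) = ||mask * (z - x)||_2
        #          + lamb * ( sum_i ||z[i + 1, :] - z[i, :]||_norm + sum_j ||z[:, j + 1] - z[:, j]||_norm ),
        # i.e. a masked data-fidelity term plus a row- and column-wise total variation penalty.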
        # Flatten inputs for the pixel-wise data-fidelity term
        x_flat = x.flatten()
        mask_flat = mask.flatten().float()

        # Data fidelity term
        res = torch.sqrt(((z_init - x_flat) ** 2 * mask_flat).sum() + eps)

        z2d = z_init.view(x.shape)

        # Total variation terms
        if norm == 1:
            # L1 norm: sum of absolute differences between neighbouring rows/columns
            tv_h = lamb * torch.abs(z2d[1:, :] - z2d[:-1, :]).sum(dim=1).sum()
            tv_w = lamb * torch.abs(z2d[:, 1:] - z2d[:, :-1]).sum(dim=0).sum()
        elif norm == 2:
            # L2 norm: sqrt of the sum of squares per row/column, then summed
            tv_h = lamb * torch.sqrt(((z2d[1:, :] - z2d[:-1, :]) ** 2).sum(dim=1) + eps).sum()
            tv_w = lamb * torch.sqrt(((z2d[:, 1:] - z2d[:, :-1]) ** 2).sum(dim=0) + eps).sum()
        else:
            # General Lp norm
            tv_h = lamb * torch.pow(torch.abs(z2d[1:, :] - z2d[:-1, :]), norm).sum(dim=1).pow(1 / norm).sum()
            tv_w = lamb * torch.pow(torch.abs(z2d[:, 1:] - z2d[:, :-1]), norm).sum(dim=0).pow(1 / norm).sum()

        tv = tv_h + tv_w

        return res + tv

    def _check_params(self) -> None:
        if not isinstance(self.prob, (float, int)) or self.prob < 0.0 or self.prob > 1.0:
            logger.error("Probability must be between 0 and 1.")
            raise ValueError("Probability must be between 0 and 1.")

        if not isinstance(self.norm, int) or self.norm <= 0:
            logger.error("Norm must be a positive integer.")
            raise ValueError("Norm must be a positive integer.")

        if not isinstance(self.max_iter, int) or self.max_iter <= 0:
            logger.error("Number of iterations must be a positive integer.")
            raise ValueError("Number of iterations must be a positive integer.")

        if self.clip_values is not None and len(self.clip_values) != 2:
            raise ValueError("'clip_values' should be a tuple of 2 floats or arrays containing the allowed data range.")

        if self.clip_values is not None and np.array(self.clip_values[0] >= self.clip_values[1]).any():
            raise ValueError("Invalid 'clip_values': min >= max.")

        if not isinstance(self.verbose, bool):
            raise ValueError("The argument `verbose` has to be of type bool.")
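

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): clean a random batch of CHW images.
    # It assumes the `PreprocessorPyTorch` base class converts NumPy inputs to tensors
    # and back when the defence is called, as elsewhere in ART; the random data below
    # is just a stand-in for real inputs.
    x_demo = np.random.rand(2, 3, 32, 32).astype(np.float32)
    defence = TotalVarMinPyTorch(prob=0.3, norm=2, lamb=0.5, max_iter=5, clip_values=(0.0, 1.0), verbose=True)
    x_cleaned, _ = defence(x_demo)
    print("Output shape:", x_cleaned.shape)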