 15 |  15 |
 16 |  16 | from tqdm.auto import tqdm
 17 |  17 |
 18 |     | -import torch
 19 |  18 | import numpy as np
 20 |  19 |
 21 |  20 | from art.defences.preprocessor.preprocessor import PreprocessorPyTorch
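Dropping the module-level `import torch` suggests ART's usual lazy-import convention for framework-specific code, and would explain why the annotations later in this diff become string literals (`"torch.Tensor"`). A minimal sketch of that pattern, assuming the import moves under `TYPE_CHECKING` and into the methods that need it:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers, so importing this module at
    # runtime no longer requires torch to be installed.
    import torch


class Example:
    def forward(self, x: "torch.Tensor") -> "torch.Tensor":
        import torch  # deferred: torch is loaded only when actually used

        return torch.clamp(x, 0.0, 1.0)
```

With string annotations, `torch.Tensor` is never evaluated at import time, so the module can still be imported (e.g. for documentation builds) on machines without PyTorch.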
@@ -65,7 +64,11 @@ def __init__(
 65 |  64 |         :param verbose: Show progress bars.
 66 |  65 |         :param device_type: Type of device on which the classifier is run, either `gpu` or `cpu`.
 67 |  66 |         """
 68 |     | -        super().__init__(device_type=device_type, apply_fit=apply_fit, apply_predict=apply_predict)
    |  67 | +        super().__init__(
    |  68 | +            device_type=device_type,
    |  69 | +            apply_fit=apply_fit,
    |  70 | +            apply_predict=apply_predict,
    |  71 | +        )
 69 |  72 |
 70 |  73 |         self.prob = prob
 71 |  74 |         self.norm = norm
@@ -107,16 +110,18 @@ def forward(
107 | 110 |             if torch.sum(mask) > 0:
108 | 111 |                 x_preproc[i] = self._minimize(x_preproc[i], mask)
109 | 112 |
110 |     | -        if self.clip_values is not None:
111 |     | -            x_preproc = torch.clamp(x_preproc, self.clip_values[0], self.clip_values[1])
112 |     | -
    | 113 | +        # BCHW -> BHWC
113 | 114 |         if not self.channels_first:
114 |     | -            # BCHW -> BHWC
115 | 115 |             x_preproc = x_preproc.permute(0, 2, 3, 1)
116 | 116 |
    | 117 | +        if self.clip_values is not None:
    | 118 | +            clip_min = torch.tensor(self.clip_values[0], device=x_preproc.device)
    | 119 | +            clip_max = torch.tensor(self.clip_values[1], device=x_preproc.device)
    | 120 | +            x_preproc = x_preproc.clamp(min=clip_min, max=clip_max)
    | 121 | +
117 | 122 |         return x_preproc, y
118 | 123 |
119 |     | -    def _minimize(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    | 124 | +    def _minimize(self, x: "torch.Tensor", mask: "torch.Tensor") -> "torch.Tensor":
120 | 125 |         """
121 | 126 |         Minimize the total variance objective function for a single 3D image by
122 | 127 |         iterating through its channels.
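Two behavioral notes on the `forward` changes above: clamping now happens after the permute back to BHWC, and the bounds are materialized as tensors on `x_preproc`'s device, which avoids a CPU/GPU mismatch when `clip_values` arrive as NumPy arrays. A standalone sketch of the broadcasting this enables (per-channel bounds are an assumption here, not something this diff requires):

```python
import numpy as np
import torch

# Hypothetical per-channel data range; shape (3,) broadcasts over BHWC.
clip_values = (np.zeros(3, dtype=np.float32), np.ones(3, dtype=np.float32))

x = torch.randn(2, 32, 32, 3)  # BHWC, i.e. after the permute above
clip_min = torch.tensor(clip_values[0], device=x.device)
clip_max = torch.tensor(clip_values[1], device=x.device)

# Tensor-valued bounds broadcast against the trailing channel axis.
x = x.clamp(min=clip_min, max=clip_max)
```

Note that tensor-valued `min`/`max` arguments to `clamp` need a reasonably recent PyTorch; scalar `clip_values` keep working unchanged.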
@@ -162,8 +167,8 @@ def closure():
162 | 167 |
163 | 168 |     @staticmethod
164 | 169 |     def _loss_func(
165 |     | -        z_init: torch.Tensor, x: torch.Tensor, mask: torch.Tensor, norm: float, lamb: float, eps=1e-6
166 |     | -    ) -> torch.Tensor:
    | 170 | +        z_init: "torch.Tensor", x: "torch.Tensor", mask: "torch.Tensor", norm: float, lamb: float, eps=1e-6
    | 171 | +    ) -> "torch.Tensor":
167 | 172 |         """
168 | 173 |         Loss function to be minimized - try to match SciPy implementation closely.
169 | 174 |
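For context on what `_loss_func` is matching: the SciPy-based TVM defense (Guo et al., 2018) minimizes a fidelity term over the randomly kept pixels plus `lamb` times a total-variation term under the given `norm`. A rough per-channel sketch of that objective shape (an illustration only, not the actual body of `_loss_func`, which this hunk doesn't show):

```python
import torch


def tv_objective(z, x, mask, norm=2.0, lamb=0.5, eps=1e-6):
    """Sketch: fit the kept pixels of x, penalize variation in z (both 2-D)."""
    # Fidelity: only pixels selected by the Bernoulli mask constrain z.
    fit = torch.norm((z - x) * mask, p=2)
    # Total variation: p-norm of differences between neighboring pixels;
    # eps keeps the gradient of the root finite when differences vanish.
    tv_rows = (((z[1:, :] - z[:-1, :]).abs() ** norm).sum() + eps) ** (1.0 / norm)
    tv_cols = (((z[:, 1:] - z[:, :-1]).abs() ** norm).sum() + eps) ** (1.0 / norm)
    return fit + lamb * (tv_rows + tv_cols)
```

`z` here corresponds to the single channel being optimized by the L-BFGS `closure()` named in the hunk header.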
@@ -216,15 +221,11 @@ def _check_params(self) -> None:
216 | 221 |             logger.error("Number of iterations must be a positive integer.")
217 | 222 |             raise ValueError("Number of iterations must be a positive integer.")
218 | 223 |
219 |     | -        if self.clip_values is not None:
220 |     | -
221 |     | -            if len(self.clip_values) != 2:
222 |     | -                raise ValueError(
223 |     | -                    "'clip_values' should be a tuple of 2 floats or arrays containing the allowed data range."
224 |     | -                )
    | 224 | +        if self.clip_values is not None and len(self.clip_values) != 2:
    | 225 | +            raise ValueError("'clip_values' should be a tuple of 2 floats or arrays containing the allowed data range.")
225 | 226 |
226 |     | -            if np.array(self.clip_values[0] >= self.clip_values[1]).any():
227 |     | -                raise ValueError("Invalid 'clip_values': min >= max.")
    | 227 | +        if self.clip_values is not None and np.array(self.clip_values[0] >= self.clip_values[1]).any():
    | 228 | +            raise ValueError("Invalid 'clip_values': min >= max.")
228 | 229 |
229 | 230 |         if not isinstance(self.verbose, bool):
230 | 231 |             raise ValueError("The argument `verbose` has to be of type bool.")
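The flattened `clip_values` checks read better, and the `np.array(...).any()` form still covers both scalar and array-valued bounds. A quick illustration with made-up values:

```python
import numpy as np

# Scalar bounds: the comparison yields a single bool.
assert not np.array(0.0 >= 1.0).any()

# Array bounds: element-wise comparison; .any() flags any inverted pair.
lo = np.array([0.0, 0.5, 0.0])
hi = np.array([1.0, 0.4, 1.0])  # second entry is invalid: min >= max
assert np.array(lo >= hi).any()
```

One could also keep a single `if self.clip_values is not None:` block around both checks, but the flat form mirrors the style of the other validations in `_check_params`.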