Skip to content

Commit e98fa82

Browse files
committed
Fix clamp tensor arguments with proper device placement and signature
Signed-off-by: Matteo Fasulo <[email protected]>
1 parent dc992d1 commit e98fa82

File tree

1 file changed

+18
-17
lines changed

1 file changed

+18
-17
lines changed

art/defences/preprocessor/variance_minimization_pytorch.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
from tqdm.auto import tqdm
1717

18-
import torch
1918
import numpy as np
2019

2120
from art.defences.preprocessor.preprocessor import PreprocessorPyTorch
@@ -65,7 +64,11 @@ def __init__(
6564
:param verbose: Show progress bars.
6665
:param device_type: Type of device on which the classifier is run, either `gpu` or `cpu`.
6766
"""
68-
super().__init__(device_type=device_type, apply_fit=apply_fit, apply_predict=apply_predict)
67+
super().__init__(
68+
device_type=device_type,
69+
apply_fit=apply_fit,
70+
apply_predict=apply_predict,
71+
)
6972

7073
self.prob = prob
7174
self.norm = norm
@@ -107,16 +110,18 @@ def forward(
107110
if torch.sum(mask) > 0:
108111
x_preproc[i] = self._minimize(x_preproc[i], mask)
109112

110-
if self.clip_values is not None:
111-
x_preproc = torch.clamp(x_preproc, self.clip_values[0], self.clip_values[1])
112-
113+
# BCHW -> BHWC
113114
if not self.channels_first:
114-
# BCHW -> BHWC
115115
x_preproc = x_preproc.permute(0, 2, 3, 1)
116116

117+
if self.clip_values is not None:
118+
clip_min = torch.tensor(self.clip_values[0], device=x_preproc.device)
119+
clip_max = torch.tensor(self.clip_values[1], device=x_preproc.device)
120+
x_preproc = x_preproc.clamp(min=clip_min, max=clip_max)
121+
117122
return x_preproc, y
118123

119-
def _minimize(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
124+
def _minimize(self, x: "torch.Tensor", mask: "torch.Tensor") -> "torch.Tensor":
120125
"""
121126
Minimize the total variance objective function for a single 3D image by
122127
iterating through its channels.
@@ -162,8 +167,8 @@ def closure():
162167

163168
@staticmethod
164169
def _loss_func(
165-
z_init: torch.Tensor, x: torch.Tensor, mask: torch.Tensor, norm: float, lamb: float, eps=1e-6
166-
) -> torch.Tensor:
170+
z_init: "torch.Tensor", x: "torch.Tensor", mask: "torch.Tensor", norm: float, lamb: float, eps=1e-6
171+
) -> "torch.Tensor":
167172
"""
168173
Loss function to be minimized - try to match SciPy implementation closely.
169174
@@ -216,15 +221,11 @@ def _check_params(self) -> None:
216221
logger.error("Number of iterations must be a positive integer.")
217222
raise ValueError("Number of iterations must be a positive integer.")
218223

219-
if self.clip_values is not None:
220-
221-
if len(self.clip_values) != 2:
222-
raise ValueError(
223-
"'clip_values' should be a tuple of 2 floats or arrays containing the allowed data range."
224-
)
224+
if self.clip_values is not None and len(self.clip_values) != 2:
225+
raise ValueError("'clip_values' should be a tuple of 2 floats or arrays containing the allowed data range.")
225226

226-
if np.array(self.clip_values[0] >= self.clip_values[1]).any():
227-
raise ValueError("Invalid 'clip_values': min >= max.")
227+
if self.clip_values is not None and np.array(self.clip_values[0] >= self.clip_values[1]).any():
228+
raise ValueError("Invalid 'clip_values': min >= max.")
228229

229230
if not isinstance(self.verbose, bool):
230231
raise ValueError("The argument `verbose` has to be of type bool.")

0 commit comments

Comments
 (0)