Skip to content

Commit 4f41ea7

Browse files
authored
Merge pull request #1444 from nilakshdas/fix-1442
Set _device parameter automatically in PreprocessorPyTorch (fixes #1442)
2 parents a7c1f7b + 1377522 commit 4f41ea7

File tree

7 files changed

+74
-45
lines changed

7 files changed

+74
-45
lines changed

art/defences/preprocessor/mp3_compression_pytorch.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,22 +62,19 @@ def __init__(
6262
:param device_type: Type of device on which the classifier is run, either `gpu` or `cpu`.
6363
:param verbose: Show progress bars.
6464
"""
65-
import torch # lgtm [py/repeated-import]
6665
from torch.autograd import Function
6766

68-
super().__init__(is_fitted=True, apply_fit=apply_fit, apply_predict=apply_predict)
67+
super().__init__(
68+
device_type=device_type,
69+
is_fitted=True,
70+
apply_fit=apply_fit,
71+
apply_predict=apply_predict,
72+
)
6973
self.channels_first = channels_first
7074
self.sample_rate = sample_rate
7175
self.verbose = verbose
7276
self._check_params()
7377

74-
# Set device
75-
if device_type == "cpu" or not torch.cuda.is_available():
76-
self._device = torch.device("cpu")
77-
else: # pragma: no cover
78-
cuda_idx = torch.cuda.current_device()
79-
self._device = torch.device("cuda:{}".format(cuda_idx))
80-
8178
self.compression_numpy = Mp3Compression(
8279
sample_rate=sample_rate,
8380
channels_first=channels_first,

art/defences/preprocessor/preprocessor.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,18 @@ class PreprocessorPyTorch(Preprocessor):
140140
Abstract base class for preprocessing defences implemented in PyTorch that support efficient preprocessor-chaining.
141141
"""
142142

143+
def __init__(self, device_type: str = "gpu", **kwargs):
144+
import torch # lgtm [py/repeated-import]
145+
146+
super().__init__(**kwargs)
147+
148+
# Set device
149+
if device_type == "cpu" or not torch.cuda.is_available():
150+
self._device = torch.device("cpu")
151+
else: # pragma: no cover
152+
cuda_idx = torch.cuda.current_device()
153+
self._device = torch.device("cuda:{}".format(cuda_idx))
154+
143155
@abc.abstractmethod
144156
def forward(
145157
self, x: "torch.Tensor", y: Optional["torch.Tensor"] = None

art/defences/preprocessor/spatial_smoothing_pytorch.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -72,22 +72,18 @@ def __init__(
7272
:param apply_predict: True if applied during predicting.
7373
:param device_type: Type of device on which the classifier is run, either `gpu` or `cpu`.
7474
"""
75-
import torch # lgtm [py/repeated-import]
7675

77-
super().__init__(apply_fit=apply_fit, apply_predict=apply_predict)
76+
super().__init__(
77+
device_type=device_type,
78+
apply_fit=apply_fit,
79+
apply_predict=apply_predict,
80+
)
7881

7982
self.channels_first = channels_first
8083
self.window_size = window_size
8184
self.clip_values = clip_values
8285
self._check_params()
8386

84-
# Set device
85-
if device_type == "cpu" or not torch.cuda.is_available():
86-
self._device = torch.device("cpu")
87-
else: # pragma: no cover
88-
cuda_idx = torch.cuda.current_device()
89-
self._device = torch.device("cuda:{}".format(cuda_idx))
90-
9187
from kornia.filters import MedianBlur
9288

9389
class MedianBlurCustom(MedianBlur):

art/defences/preprocessor/video_compression_pytorch.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -68,23 +68,20 @@ def __init__(
6868
:param device_type: Type of device on which the classifier is run, either `gpu` or `cpu`.
6969
:param verbose: Show progress bars.
7070
"""
71-
import torch # lgtm [py/repeated-import]
7271
from torch.autograd import Function
7372

74-
super().__init__(is_fitted=True, apply_fit=apply_fit, apply_predict=apply_predict)
73+
super().__init__(
74+
device_type=device_type,
75+
is_fitted=True,
76+
apply_fit=apply_fit,
77+
apply_predict=apply_predict,
78+
)
7579
self.video_format = video_format
7680
self.constant_rate_factor = constant_rate_factor
7781
self.channels_first = channels_first
7882
self.verbose = verbose
7983
self._check_params()
8084

81-
# Set device
82-
if device_type == "cpu" or not torch.cuda.is_available():
83-
self._device = torch.device("cpu")
84-
else: # pragma: no cover
85-
cuda_idx = torch.cuda.current_device()
86-
self._device = torch.device("cuda:{}".format(cuda_idx))
87-
8885
self.compression_numpy = VideoCompression(
8986
video_format=video_format,
9087
constant_rate_factor=constant_rate_factor,

art/preprocessing/audio/l_filter/pytorch.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -71,23 +71,20 @@ def __init__(
7171
:param verbose: Show progress bars.
7272
:param device_type: Type of device on which the classifier is run, either `gpu` or `cpu`.
7373
"""
74-
import torch # lgtm [py/repeated-import]
7574

76-
super().__init__(is_fitted=True, apply_fit=apply_fit, apply_predict=apply_predict)
75+
super().__init__(
76+
device_type=device_type,
77+
is_fitted=True,
78+
apply_fit=apply_fit,
79+
apply_predict=apply_predict,
80+
)
7781

7882
self.numerator_coef = numerator_coef
7983
self.denominator_coef = denominator_coef
8084
self.clip_values = clip_values
8185
self.verbose = verbose
8286
self._check_params()
8387

84-
# Set device
85-
if device_type == "cpu" or not torch.cuda.is_available():
86-
self._device = torch.device("cpu")
87-
else: # pragma: no cover
88-
cuda_idx = torch.cuda.current_device()
89-
self._device = torch.device("cuda:{}".format(cuda_idx))
90-
9188
def forward(
9289
self, x: "torch.Tensor", y: Optional["torch.Tensor"] = None
9390
) -> Tuple["torch.Tensor", Optional["torch.Tensor"]]:

art/preprocessing/standardisation_mean_std/pytorch.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,13 @@ def __init__(
5454
:param mean: Mean.
5555
:param std: Standard Deviation.
5656
"""
57-
import torch # lgtm [py/repeated-import]
5857

59-
super().__init__(is_fitted=True, apply_fit=apply_fit, apply_predict=apply_predict)
58+
super().__init__(
59+
device_type=device_type,
60+
is_fitted=True,
61+
apply_fit=apply_fit,
62+
apply_predict=apply_predict,
63+
)
6064
self.mean = np.asarray(mean, dtype=ART_NUMPY_DTYPE)
6165
self.std = np.asarray(std, dtype=ART_NUMPY_DTYPE)
6266
self._check_params()
@@ -65,13 +69,6 @@ def __init__(
6569
self._broadcastable_mean = None
6670
self._broadcastable_std = None
6771

68-
# Set device
69-
if device_type == "cpu" or not torch.cuda.is_available():
70-
self._device = torch.device("cpu")
71-
else: # pragma: no cover
72-
cuda_idx = torch.cuda.current_device()
73-
self._device = torch.device("cuda:{}".format(cuda_idx))
74-
7572
def forward(
7673
self, x: "torch.Tensor", y: Optional["torch.Tensor"] = None
7774
) -> Tuple["torch.Tensor", Optional["torch.Tensor"]]:
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import pytest
2+
3+
from art.defences.preprocessor.preprocessor import PreprocessorPyTorch
4+
from tests.utils import ARTTestException
5+
6+
7+
class DummyPreprocessorPyTorch(PreprocessorPyTorch):
8+
def forward(self, x, y):
9+
return x, y
10+
11+
12+
@pytest.mark.parametrize("is_fitted", [True, False])
13+
@pytest.mark.parametrize("apply_fit", [True, False])
14+
@pytest.mark.parametrize("apply_predict", [True, False])
15+
@pytest.mark.only_with_platform("pytorch")
16+
def test_preprocessor_pytorch_init(art_warning, is_fitted, apply_fit, apply_predict):
17+
try:
18+
import torch
19+
20+
preprocessor = DummyPreprocessorPyTorch(
21+
device_type="cpu",
22+
is_fitted=is_fitted,
23+
apply_fit=apply_fit,
24+
apply_predict=apply_predict,
25+
)
26+
27+
assert preprocessor.device == torch.device("cpu")
28+
assert preprocessor.is_fitted == is_fitted
29+
assert preprocessor.apply_fit == apply_fit
30+
assert preprocessor.apply_predict == apply_predict
31+
32+
except ARTTestException as e:
33+
art_warning(e)

0 commit comments

Comments
 (0)