Skip to content

Commit 5c84b3a

Browse files
committed
Set device using parent class constructor
Signed-off-by: Nilaksh Das <[email protected]>
1 parent 4f3df77 commit 5c84b3a

File tree

5 files changed

+29
-40
lines changed

5 files changed

+29
-40
lines changed

art/defences/preprocessor/mp3_compression_pytorch.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,19 +67,17 @@ def __init__(
6767
import torch # lgtm [py/repeated-import]
6868
from torch.autograd import Function
6969

70-
super().__init__(is_fitted=True, apply_fit=apply_fit, apply_predict=apply_predict)
70+
super().__init__(
71+
device_type=device_type,
72+
is_fitted=True,
73+
apply_fit=apply_fit,
74+
apply_predict=apply_predict,
75+
)
7176
self.channels_first = channels_first
7277
self.sample_rate = sample_rate
7378
self.verbose = verbose
7479
self._check_params()
7580

76-
# Set device
77-
if device_type == "cpu" or not torch.cuda.is_available():
78-
self._device = torch.device("cpu")
79-
else: # pragma: no cover
80-
cuda_idx = torch.cuda.current_device()
81-
self._device = torch.device("cuda:{}".format(cuda_idx))
82-
8381
self.compression_numpy = Mp3Compression(
8482
sample_rate=sample_rate,
8583
channels_first=channels_first,

art/defences/preprocessor/spatial_smoothing_pytorch.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -74,20 +74,17 @@ def __init__(
7474
"""
7575
import torch # lgtm [py/repeated-import]
7676

77-
super().__init__(apply_fit=apply_fit, apply_predict=apply_predict)
77+
super().__init__(
78+
device_type=device_type,
79+
apply_fit=apply_fit,
80+
apply_predict=apply_predict,
81+
)
7882

7983
self.channels_first = channels_first
8084
self.window_size = window_size
8185
self.clip_values = clip_values
8286
self._check_params()
8387

84-
# Set device
85-
if device_type == "cpu" or not torch.cuda.is_available():
86-
self._device = torch.device("cpu")
87-
else: # pragma: no cover
88-
cuda_idx = torch.cuda.current_device()
89-
self._device = torch.device("cuda:{}".format(cuda_idx))
90-
9188
from kornia.filters import MedianBlur
9289

9390
class MedianBlurCustom(MedianBlur):

art/defences/preprocessor/video_compression_pytorch.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,20 +73,18 @@ def __init__(
7373
import torch # lgtm [py/repeated-import]
7474
from torch.autograd import Function
7575

76-
super().__init__(is_fitted=True, apply_fit=apply_fit, apply_predict=apply_predict)
76+
super().__init__(
77+
device_type=device_type,
78+
is_fitted=True,
79+
apply_fit=apply_fit,
80+
apply_predict=apply_predict,
81+
)
7782
self.video_format = video_format
7883
self.constant_rate_factor = constant_rate_factor
7984
self.channels_first = channels_first
8085
self.verbose = verbose
8186
self._check_params()
8287

83-
# Set device
84-
if device_type == "cpu" or not torch.cuda.is_available():
85-
self._device = torch.device("cpu")
86-
else: # pragma: no cover
87-
cuda_idx = torch.cuda.current_device()
88-
self._device = torch.device("cuda:{}".format(cuda_idx))
89-
9088
self.compression_numpy = VideoCompression(
9189
video_format=video_format,
9290
constant_rate_factor=constant_rate_factor,

art/preprocessing/audio/l_filter/pytorch.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,21 +73,19 @@ def __init__(
7373
"""
7474
import torch # lgtm [py/repeated-import]
7575

76-
super().__init__(is_fitted=True, apply_fit=apply_fit, apply_predict=apply_predict)
76+
super().__init__(
77+
device_type=device_type,
78+
is_fitted=True,
79+
apply_fit=apply_fit,
80+
apply_predict=apply_predict,
81+
)
7782

7883
self.numerator_coef = numerator_coef
7984
self.denominator_coef = denominator_coef
8085
self.clip_values = clip_values
8186
self.verbose = verbose
8287
self._check_params()
8388

84-
# Set device
85-
if device_type == "cpu" or not torch.cuda.is_available():
86-
self._device = torch.device("cpu")
87-
else: # pragma: no cover
88-
cuda_idx = torch.cuda.current_device()
89-
self._device = torch.device("cuda:{}".format(cuda_idx))
90-
9189
def forward(
9290
self, x: "torch.Tensor", y: Optional["torch.Tensor"] = None
9391
) -> Tuple["torch.Tensor", Optional["torch.Tensor"]]:

art/preprocessing/standardisation_mean_std/pytorch.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,12 @@ def __init__(
5656
"""
5757
import torch # lgtm [py/repeated-import]
5858

59-
super().__init__(is_fitted=True, apply_fit=apply_fit, apply_predict=apply_predict)
59+
super().__init__(
60+
device_type=device_type,
61+
is_fitted=True,
62+
apply_fit=apply_fit,
63+
apply_predict=apply_predict,
64+
)
6065
self.mean = np.asarray(mean, dtype=ART_NUMPY_DTYPE)
6166
self.std = np.asarray(std, dtype=ART_NUMPY_DTYPE)
6267
self._check_params()
@@ -65,13 +70,6 @@ def __init__(
6570
self._broadcastable_mean = None
6671
self._broadcastable_std = None
6772

68-
# Set device
69-
if device_type == "cpu" or not torch.cuda.is_available():
70-
self._device = torch.device("cpu")
71-
else: # pragma: no cover
72-
cuda_idx = torch.cuda.current_device()
73-
self._device = torch.device("cuda:{}".format(cuda_idx))
74-
7573
def forward(
7674
self, x: "torch.Tensor", y: Optional["torch.Tensor"] = None
7775
) -> Tuple["torch.Tensor", Optional["torch.Tensor"]]:

0 commit comments

Comments
 (0)