File tree Expand file tree Collapse file tree 5 files changed +29
-40
lines changed Expand file tree Collapse file tree 5 files changed +29
-40
lines changed Original file line number Diff line number Diff line change @@ -67,19 +67,17 @@ def __init__(
6767 import torch # lgtm [py/repeated-import]
6868 from torch .autograd import Function
6969
70- super ().__init__ (is_fitted = True , apply_fit = apply_fit , apply_predict = apply_predict )
70+ super ().__init__ (
71+ device_type = device_type ,
72+ is_fitted = True ,
73+ apply_fit = apply_fit ,
74+ apply_predict = apply_predict ,
75+ )
7176 self .channels_first = channels_first
7277 self .sample_rate = sample_rate
7378 self .verbose = verbose
7479 self ._check_params ()
7580
76- # Set device
77- if device_type == "cpu" or not torch .cuda .is_available ():
78- self ._device = torch .device ("cpu" )
79- else : # pragma: no cover
80- cuda_idx = torch .cuda .current_device ()
81- self ._device = torch .device ("cuda:{}" .format (cuda_idx ))
82-
8381 self .compression_numpy = Mp3Compression (
8482 sample_rate = sample_rate ,
8583 channels_first = channels_first ,
Original file line number Diff line number Diff line change @@ -74,20 +74,17 @@ def __init__(
7474 """
7575 import torch # lgtm [py/repeated-import]
7676
77- super ().__init__ (apply_fit = apply_fit , apply_predict = apply_predict )
77+ super ().__init__ (
78+ device_type = device_type ,
79+ apply_fit = apply_fit ,
80+ apply_predict = apply_predict ,
81+ )
7882
7983 self .channels_first = channels_first
8084 self .window_size = window_size
8185 self .clip_values = clip_values
8286 self ._check_params ()
8387
84- # Set device
85- if device_type == "cpu" or not torch .cuda .is_available ():
86- self ._device = torch .device ("cpu" )
87- else : # pragma: no cover
88- cuda_idx = torch .cuda .current_device ()
89- self ._device = torch .device ("cuda:{}" .format (cuda_idx ))
90-
9188 from kornia .filters import MedianBlur
9289
9390 class MedianBlurCustom (MedianBlur ):
Original file line number Diff line number Diff line change @@ -73,20 +73,18 @@ def __init__(
7373 import torch # lgtm [py/repeated-import]
7474 from torch .autograd import Function
7575
76- super ().__init__ (is_fitted = True , apply_fit = apply_fit , apply_predict = apply_predict )
76+ super ().__init__ (
77+ device_type = device_type ,
78+ is_fitted = True ,
79+ apply_fit = apply_fit ,
80+ apply_predict = apply_predict ,
81+ )
7782 self .video_format = video_format
7883 self .constant_rate_factor = constant_rate_factor
7984 self .channels_first = channels_first
8085 self .verbose = verbose
8186 self ._check_params ()
8287
83- # Set device
84- if device_type == "cpu" or not torch .cuda .is_available ():
85- self ._device = torch .device ("cpu" )
86- else : # pragma: no cover
87- cuda_idx = torch .cuda .current_device ()
88- self ._device = torch .device ("cuda:{}" .format (cuda_idx ))
89-
9088 self .compression_numpy = VideoCompression (
9189 video_format = video_format ,
9290 constant_rate_factor = constant_rate_factor ,
Original file line number Diff line number Diff line change @@ -73,21 +73,19 @@ def __init__(
7373 """
7474 import torch # lgtm [py/repeated-import]
7575
76- super ().__init__ (is_fitted = True , apply_fit = apply_fit , apply_predict = apply_predict )
76+ super ().__init__ (
77+ device_type = device_type ,
78+ is_fitted = True ,
79+ apply_fit = apply_fit ,
80+ apply_predict = apply_predict ,
81+ )
7782
7883 self .numerator_coef = numerator_coef
7984 self .denominator_coef = denominator_coef
8085 self .clip_values = clip_values
8186 self .verbose = verbose
8287 self ._check_params ()
8388
84- # Set device
85- if device_type == "cpu" or not torch .cuda .is_available ():
86- self ._device = torch .device ("cpu" )
87- else : # pragma: no cover
88- cuda_idx = torch .cuda .current_device ()
89- self ._device = torch .device ("cuda:{}" .format (cuda_idx ))
90-
9189 def forward (
9290 self , x : "torch.Tensor" , y : Optional ["torch.Tensor" ] = None
9391 ) -> Tuple ["torch.Tensor" , Optional ["torch.Tensor" ]]:
Original file line number Diff line number Diff line change @@ -56,7 +56,12 @@ def __init__(
5656 """
5757 import torch # lgtm [py/repeated-import]
5858
59- super ().__init__ (is_fitted = True , apply_fit = apply_fit , apply_predict = apply_predict )
59+ super ().__init__ (
60+ device_type = device_type ,
61+ is_fitted = True ,
62+ apply_fit = apply_fit ,
63+ apply_predict = apply_predict ,
64+ )
6065 self .mean = np .asarray (mean , dtype = ART_NUMPY_DTYPE )
6166 self .std = np .asarray (std , dtype = ART_NUMPY_DTYPE )
6267 self ._check_params ()
@@ -65,13 +70,6 @@ def __init__(
6570 self ._broadcastable_mean = None
6671 self ._broadcastable_std = None
6772
68- # Set device
69- if device_type == "cpu" or not torch .cuda .is_available ():
70- self ._device = torch .device ("cpu" )
71- else : # pragma: no cover
72- cuda_idx = torch .cuda .current_device ()
73- self ._device = torch .device ("cuda:{}" .format (cuda_idx ))
74-
7573 def forward (
7674 self , x : "torch.Tensor" , y : Optional ["torch.Tensor" ] = None
7775 ) -> Tuple ["torch.Tensor" , Optional ["torch.Tensor" ]]:
You can’t perform that action at this time.
0 commit comments