@@ -51,7 +51,7 @@ def __init__(
5151 'plot_pref' : False ,
5252 },
5353 batch_size : int = 10 ,
54- DEVICE_compute = 'cpu' ,
54+ device = 'cpu' ,
5555 normalization_factor : float = 0.99 ,
5656 spectrogram_exponent : float = 1.0 ,
5757 one_over_f_exponent : float = 1.0 ,
@@ -62,7 +62,7 @@ def __init__(
6262 ## Set attributes
6363 self ._params_VQT = params_VQT
6464 self ._batch_size = int (batch_size )
65- self ._DEVICE_compute = DEVICE_compute
65+ self ._device = device
6666 self ._normalization_factor = float (normalization_factor )
6767 self ._spectrogram_exponent = float (spectrogram_exponent )
6868 self ._one_over_f_exponent = float (one_over_f_exponent )
@@ -84,7 +84,7 @@ def __init__(
8484 self .config = {
8585 'params_VQT' : params_VQT ,
8686 'batch_size' : batch_size ,
87- 'DEVICE_compute ' : DEVICE_compute ,
87+ 'device ' : device ,
8888 'normalization_factor' : normalization_factor ,
8989 'spectrogram_exponent' : spectrogram_exponent ,
9090 'one_over_f_exponent' : one_over_f_exponent ,
@@ -135,9 +135,9 @@ def transform(self, points_tracked: np.ndarray, point_positions: np.ndarray):
135135 freqs = self .vqt_model .freqs
136136 xAxis = self .vqt_model .get_xAxis (points_tracked .shape [- 1 ])
137137 ### send vqt_model to device
138- self .vqt_model .to (self ._DEVICE_compute )
138+ self .vqt_model .to (self ._device )
139139 spec = torch .cat ([
140- self .vqt_model (p .to (self ._DEVICE_compute )).cpu ()
140+ self .vqt_model (p .to (self ._device )).cpu ()
141141 for p in tqdm (helpers .make_batches (points_tracked , batch_size = self ._batch_size ), disable = not self ._verbose > 1 , desc = 'Computing spectrograms' , leave = True , position = 0 , total = int (math .ceil (points_tracked .shape [0 ] / self ._batch_size )))
142142 ], dim = 0 )
143143 self .vqt_model .to ('cpu' )
0 commit comments