Skip to content

Commit 77702f4

Browse files
committed
device bug
1 parent 64a19e5 commit 77702f4

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

face_rhythm/pipelines.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def pipeline_basic(params):
222222
## - `downsample_factor`: How much to downsample the spectrogram by in time.
223223
## - `return_complex`: Whether or not to return the complex spectrogram. Generally set to False unless you want to try something fancy.
224224

225-
params['VQT_Analyzer']['params_VQT']['DEVICE_compute'] = fr.helpers.set_device(use_GPU=use_GPU)
225+
params['VQT_Analyzer']['device'] = fr.helpers.set_device(use_GPU=use_GPU)
226226

227227
spec = fr.spectral_analysis.VQT_Analyzer(**params['VQT_Analyzer'])
228228

face_rhythm/spectral_analysis.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def __init__(
5151
'plot_pref': False,
5252
},
5353
batch_size: int=10,
54-
DEVICE_compute='cpu',
54+
device='cpu',
5555
normalization_factor: float=0.99,
5656
spectrogram_exponent: float=1.0,
5757
one_over_f_exponent: float=1.0,
@@ -62,7 +62,7 @@ def __init__(
6262
## Set attributes
6363
self._params_VQT = params_VQT
6464
self._batch_size = int(batch_size)
65-
self._DEVICE_compute = DEVICE_compute
65+
self._device = device
6666
self._normalization_factor = float(normalization_factor)
6767
self._spectrogram_exponent = float(spectrogram_exponent)
6868
self._one_over_f_exponent = float(one_over_f_exponent)
@@ -84,7 +84,7 @@ def __init__(
8484
self.config = {
8585
'params_VQT': params_VQT,
8686
'batch_size': batch_size,
87-
'DEVICE_compute': DEVICE_compute,
87+
'device': device,
8888
'normalization_factor': normalization_factor,
8989
'spectrogram_exponent': spectrogram_exponent,
9090
'one_over_f_exponent': one_over_f_exponent,
@@ -135,9 +135,9 @@ def transform(self, points_tracked: np.ndarray, point_positions: np.ndarray):
135135
freqs = self.vqt_model.freqs
136136
xAxis = self.vqt_model.get_xAxis(points_tracked.shape[-1])
137137
### send vqt_model to device
138-
self.vqt_model.to(self._DEVICE_compute)
138+
self.vqt_model.to(self._device)
139139
spec = torch.cat([
140-
self.vqt_model(p.to(self._DEVICE_compute)).cpu()
140+
self.vqt_model(p.to(self._device)).cpu()
141141
for p in tqdm(helpers.make_batches(points_tracked, batch_size=self._batch_size), disable=not self._verbose > 1, desc='Computing spectrograms', leave=True, position=0, total=int(math.ceil(points_tracked.shape[0] / self._batch_size)))
142142
], dim=0)
143143
self.vqt_model.to('cpu')

0 commit comments

Comments
 (0)