Skip to content

Commit fd03959

Browse files
author
Beat Buesser
committed
Update PyTorch DeepSpeech
Signed-off-by: Beat Buesser <[email protected]>
1 parent 8e444ba commit fd03959

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

art/estimators/speech_recognition/pytorch_deep_speech.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -583,17 +583,19 @@ def transform_model_input(
583583
# Push the sequence to device
584584
if not tensor_input:
585585
x[i] = x[i].astype(ART_NUMPY_DTYPE)
586-
x[i] = torch.tensor(x[i]).to(self._device)
586+
x_i_tensor = torch.tensor(x[i]).to(self._device)
587+
else:
588+
x_i_tensor = x[i]
587589

588590
# Set gradient computation permission
589591
if compute_gradient:
590-
x[i].requires_grad = True
592+
x_i_tensor.requires_grad = True
591593

592594
# Transform the sequence into the frequency space
593595
if tensor_input and real_lengths is not None:
594-
transformed_input = transformer(x[i][: real_lengths[i]])
596+
transformed_input = transformer(x_i_tensor[: real_lengths[i]])
595597
else:
596-
transformed_input = transformer(x[i])
598+
transformed_input = transformer(x_i_tensor)
597599

598600
spectrogram, _ = torchaudio.functional.magphase(transformed_input)
599601
spectrogram = torch.log1p(spectrogram)

0 commit comments

Comments
 (0)