Skip to content

Commit b2a6973

Browse files
Port autograd code for rnnt (#3970)
Co-authored-by: Sam Anklesaria <[email protected]> Co-authored-by: Nicolas Hug <[email protected]> Co-authored-by: Nicolas Hug <[email protected]>
1 parent f3e876c commit b2a6973

File tree

6 files changed

+24
-80
lines changed

6 files changed

+24
-80
lines changed

.github/scripts/unittest-linux/run_test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,5 @@ fi
3030

3131
(
3232
cd test
33-
pytest torchaudio_unittest -k "not backend and not /io/ and not prototype and not sox and not ffmpeg and not fairseq and not hdemucs"
33+
pytest torchaudio_unittest -k "not backend and not /io/ and not prototype and not sox and not ffmpeg and not fairseq and not hdemucs and not (torchscript and rnnt)"
3434
)

.github/workflows/unittest-linux-gpu.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ jobs:
117117
'--cov=torchaudio'
118118
"--junitxml=${RUNNER_TEST_RESULTS_DIR}/junit.xml"
119119
'--durations' '100'
120-
'-k' 'cuda or gpu'
120+
'-k' '(cuda or gpu) and not (torchscript and rnnt)'
121121
)
122122
123123
cd test

src/libtorchaudio/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ if(BUILD_RNNT)
2828
rnnt/compute_alphas.cpp
2929
rnnt/compute_betas.cpp
3030
rnnt/compute.cpp
31-
rnnt/autograd.cpp
3231
)
3332
if (USE_CUDA)
3433
list(

src/libtorchaudio/rnnt/autograd.cpp

Lines changed: 0 additions & 69 deletions
This file was deleted.

src/libtorchaudio/rnnt/compute.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,5 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
3030
"int blank,"
3131
"float clamp,"
3232
"bool fused_log_softmax) -> (Tensor, Tensor?)");
33+
m.def("torchaudio::rnnt_loss_forward", &rnnt_loss);
3334
}

src/torchaudio/functional/functional.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1760,6 +1760,19 @@ def _fix_waveform_shape(
17601760
waveform_shift = waveform_shift.view(shape[:-1] + waveform_shift.shape[-1:])
17611761
return waveform_shift
17621762

1763+
class RnntLoss(torch.autograd.Function):
1764+
@staticmethod
1765+
def forward(ctx, *args):
1766+
output, saved = torch.ops.torchaudio.rnnt_loss_forward(*args)
1767+
ctx.save_for_backward(saved)
1768+
return output
1769+
1770+
@staticmethod
1771+
def backward(ctx, dy):
1772+
grad = ctx.saved_tensors[0]
1773+
grad_out = dy.view((-1, 1, 1, 1))
1774+
result = grad * grad_out;
1775+
return (result, None, None, None, None, None, None, None)
17631776

17641777
def _rnnt_loss(
17651778
logits: Tensor,
@@ -1803,14 +1816,14 @@ def _rnnt_loss(
18031816
if blank < 0: # reinterpret blank index if blank < 0.
18041817
blank = logits.shape[-1] + blank
18051818

1806-
costs, _ = torch.ops.torchaudio.rnnt_loss(
1807-
logits=logits,
1808-
targets=targets,
1809-
logit_lengths=logit_lengths,
1810-
target_lengths=target_lengths,
1811-
blank=blank,
1812-
clamp=clamp,
1813-
fused_log_softmax=fused_log_softmax,
1819+
costs = RnntLoss.apply(
1820+
logits,
1821+
targets,
1822+
logit_lengths,
1823+
target_lengths,
1824+
blank,
1825+
clamp,
1826+
fused_log_softmax
18141827
)
18151828

18161829
if reduction == "mean":

0 commit comments

Comments (0)