Skip to content

Commit 77d4421

Browse files
committed
Port autograd code for rnnt
1 parent a3fe94e commit 77d4421

File tree

2 files changed

+19
-49
lines changed

2 files changed

+19
-49
lines changed

src/libtorchaudio/rnnt/autograd.cpp

Lines changed: 3 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -3,31 +3,7 @@
33
namespace torchaudio {
44
namespace rnnt {
55

6-
class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
7-
public:
8-
static torch::autograd::tensor_list forward(
9-
torch::autograd::AutogradContext* ctx,
10-
torch::Tensor& logits,
11-
const torch::Tensor& targets,
12-
const torch::Tensor& logit_lengths,
13-
const torch::Tensor& target_lengths,
14-
int64_t blank,
15-
double clamp,
16-
bool fused_log_softmax = true) {
17-
torch::Tensor undef;
18-
auto result = rnnt_loss(
19-
logits,
20-
targets,
21-
logit_lengths,
22-
target_lengths,
23-
blank,
24-
clamp,
25-
fused_log_softmax);
26-
auto costs = std::get<0>(result);
27-
auto grads = std::get<1>(result).value_or(undef);
28-
ctx->save_for_backward({grads});
29-
return {costs, grads};
30-
}
6+
317

328
static torch::autograd::tensor_list backward(
339
torch::autograd::AutogradContext* ctx,
@@ -39,31 +15,10 @@ class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
3915
torch::Tensor undef;
4016
return {result, undef, undef, undef, undef, undef, undef, undef};
4117
}
42-
};
43-
44-
std::tuple<torch::Tensor, std::optional<torch::Tensor>> rnnt_loss_autograd(
45-
torch::Tensor& logits,
46-
const torch::Tensor& targets,
47-
const torch::Tensor& logit_lengths,
48-
const torch::Tensor& target_lengths,
49-
int64_t blank,
50-
double clamp,
51-
bool fused_log_softmax = true) {
52-
at::AutoDispatchBelowADInplaceOrView guard;
53-
auto results = RNNTLossFunction::apply(
54-
logits,
55-
targets,
56-
logit_lengths,
57-
target_lengths,
58-
blank,
59-
clamp,
60-
fused_log_softmax);
61-
return std::make_tuple(results[0], results[1]);
6218
}
6319

64-
TORCH_LIBRARY_IMPL(torchaudio, Autograd, m) {
65-
m.impl("rnnt_loss", rnnt_loss_autograd);
20+
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
21+
m.def("torchaudio::rnnt_loss_forward", &rnnt_loss);
6622
}
6723

68-
} // namespace rnnt
6924
} // namespace torchaudio

src/torchaudio/functional/functional.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1760,6 +1760,21 @@ def _fix_waveform_shape(
17601760
waveform_shift = waveform_shift.view(shape[:-1] + waveform_shift.shape[-1:])
17611761
return waveform_shift
17621762

1763+
class RnntLoss(torch.autograd.Function):
    """Autograd wrapper for the RNN Transducer loss custom op.

    The C++ op ``torchaudio.rnnt_loss_forward`` computes both the per-batch
    costs and the gradients w.r.t. the logits in a single forward call; the
    gradients are stashed on the autograd context and replayed (scaled by the
    incoming gradient) in ``backward``.
    """

    @staticmethod
    def forward(ctx, *args):
        # The op returns (costs, grads). grads are precomputed here so that
        # backward only needs to rescale them.
        output, saved = torch.ops.torchaudio.rnnt_loss_forward(*args)
        ctx.save_for_backward(saved)
        return output

    @staticmethod
    def backward(ctx, dy):
        grad = ctx.saved_tensors[0]
        # dy holds one gradient entry per batch element; broadcast it over the
        # remaining three dimensions of the saved logit gradients.
        grad_out = dy.view((-1, 1, 1, 1))
        result = grad * grad_out
        # One slot per forward input; only the logits receive a gradient.
        return (result, None, None, None, None, None, None, None)
1776+
1777+
torch.ops.torchaudio.rnnt_loss_forward
17631778

17641779
def _rnnt_loss(
17651780
logits: Tensor,
@@ -1803,7 +1818,7 @@ def _rnnt_loss(
18031818
if blank < 0: # reinterpret blank index if blank < 0.
18041819
blank = logits.shape[-1] + blank
18051820

1806-
costs, _ = torch.ops.torchaudio.rnnt_loss(
1821+
costs = RnntLoss.apply(
18071822
logits=logits,
18081823
targets=targets,
18091824
logit_lengths=logit_lengths,

0 commit comments

Comments
 (0)