 namespace torchaudio {
 namespace rnnt {
 
-class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
- public:
-  static torch::autograd::tensor_list forward(
-      torch::autograd::AutogradContext* ctx,
-      torch::Tensor& logits,
-      const torch::Tensor& targets,
-      const torch::Tensor& logit_lengths,
-      const torch::Tensor& target_lengths,
-      int64_t blank,
-      double clamp,
-      bool fused_log_softmax = true) {
-    torch::Tensor undef;
-    auto result = rnnt_loss(
-        logits,
-        targets,
-        logit_lengths,
-        target_lengths,
-        blank,
-        clamp,
-        fused_log_softmax);
-    auto costs = std::get<0>(result);
-    auto grads = std::get<1>(result).value_or(undef);
-    ctx->save_for_backward({grads});
-    return {costs, grads};
-  }
+
 
   static torch::autograd::tensor_list backward(
       torch::autograd::AutogradContext* ctx,
@@ -39,31 +15,10 @@ class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
     torch::Tensor undef;
     return {result, undef, undef, undef, undef, undef, undef, undef};
   }
-};
-
-std::tuple<torch::Tensor, std::optional<torch::Tensor>> rnnt_loss_autograd(
-    torch::Tensor& logits,
-    const torch::Tensor& targets,
-    const torch::Tensor& logit_lengths,
-    const torch::Tensor& target_lengths,
-    int64_t blank,
-    double clamp,
-    bool fused_log_softmax = true) {
-  at::AutoDispatchBelowADInplaceOrView guard;
-  auto results = RNNTLossFunction::apply(
-      logits,
-      targets,
-      logit_lengths,
-      target_lengths,
-      blank,
-      clamp,
-      fused_log_softmax);
-  return std::make_tuple(results[0], results[1]);
 }
 
-TORCH_LIBRARY_IMPL(torchaudio, Autograd, m) {
-  m.impl("rnnt_loss", rnnt_loss_autograd);
+TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
+  m.def("torchaudio::rnnt_loss_forward", &rnnt_loss);
 }
 
-} // namespace rnnt
 } // namespace torchaudio
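
Note (not part of the diff itself): TORCH_LIBRARY_IMPL(torchaudio, Autograd, m) attaches a kernel to the Autograd dispatch key of an operator declared elsewhere, whereas TORCH_LIBRARY_FRAGMENT(torchaudio, m) with m.def(name, &fn) declares a new operator and infers its schema from the C++ signature of the bound function. Below is a minimal, self-contained sketch of that registration pattern; the myops namespace and the scale op are hypothetical names used only for illustration and are not part of torchaudio.

// Hypothetical example, not from torchaudio: a toy op registered with the
// same TORCH_LIBRARY_FRAGMENT + m.def(name, &fn) pattern as the diff above.
// The dispatcher infers the operator schema (Tensor and float arguments,
// Tensor return) from the C++ signature of scale.
#include <torch/library.h>
#include <torch/torch.h>

torch::Tensor scale(const torch::Tensor& x, double factor) {
  // Plain eager computation; no autograd-specific code is needed here,
  // since m.def registers this as the default kernel for the op.
  return x * factor;
}

TORCH_LIBRARY_FRAGMENT(myops, m) {
  m.def("myops::scale", &scale);
}

Once defined this way, the op is reachable through the dispatcher, for example as torch.ops.myops.scale from Python; presumably torchaudio::rnnt_loss_forward is consumed the same way after this change, with autograd handled by the caller instead of the removed RNNTLossFunction wrapper.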