@@ -12,10 +12,10 @@ using torch::stable::Tensor;
1212
1313// Entry point into RNNT Loss
1414std::tuple<Tensor, Tensor> compute (
15- const Tensor& logits,
16- const Tensor& targets,
17- const Tensor& logit_lengths,
18- const Tensor& target_lengths,
15+ Tensor logits,
16+ Tensor targets,
17+ Tensor logit_lengths,
18+ Tensor target_lengths,
1919 int64_t blank,
2020 double clamp,
2121 bool fused_log_softmax = true ) {
@@ -147,26 +147,8 @@ std::tuple<Tensor, Tensor> compute(
147147 return std::make_tuple (costs, gradients);
148148}
149149
150- void boxed_rnnt_loss (
151- StableIValue* stack,
152- uint64_t num_args,
153- uint64_t num_outputs) {
154- STD_TORCH_CHECK (num_args == 7 , " num_args must be 7" );
155- STD_TORCH_CHECK (num_outputs == 2 , " num_outputs must be 2" );
156- std::tuple<Tensor, Tensor> res = compute (
157- /* logits*/ torch::stable::detail::to<Tensor>(stack[0 ]),
158- /* targets*/ torch::stable::detail::to<Tensor>(stack[1 ]),
159- /* logit_lengths*/ torch::stable::detail::to<Tensor>(stack[2 ]),
160- /* target_lengths*/ torch::stable::detail::to<Tensor>(stack[3 ]),
161- /* blank*/ float (torch::stable::detail::to<int64_t >(stack[4 ])),
162- /* clamp*/ torch::stable::detail::to<double >(stack[5 ]),
163- /* fused_log_softmax*/ torch::stable::detail::to<bool >(stack[6 ]));
164- stack[0 ] = torch::stable::detail::from (std::get<0 >(res));
165- stack[1 ] = torch::stable::detail::from (std::get<1 >(res));
166- }
167-
168150STABLE_TORCH_LIBRARY_IMPL (torchaudio, CPU, m) {
169- m.impl (" rnnt_loss_forward" , &boxed_rnnt_loss );
151+ m.impl (" rnnt_loss_forward" , TORCH_BOX (&compute) );
170152}
171153
172154} // namespace cpu
0 commit comments