@@ -147,10 +147,10 @@ template <typename scalar_t>
147147const auto forced_align_int_impl = forced_align_impl<scalar_t , ScalarType::Int>;
148148
149149std::tuple<Tensor, Tensor> compute (
150- const Tensor& logProbs,
151- const Tensor& targets,
152- const Tensor& inputLengths,
153- const Tensor& targetLengths,
150+ Tensor logProbs,
151+ Tensor targets,
152+ Tensor inputLengths,
153+ Tensor targetLengths,
154154 const int64_t blank) {
155155 STD_TORCH_CHECK (logProbs.is_cpu (), " log_probs must be a CPU tensor" );
156156 STD_TORCH_CHECK (targets.is_cpu (), " targets must be a CPU tensor" );
@@ -224,24 +224,8 @@ std::tuple<Tensor, Tensor> compute(
224224 return std::make_tuple (paths, logProbs);
225225}
226226
227- void boxed_forced_align_cpu (
228- StableIValue* stack,
229- uint64_t num_args,
230- uint64_t num_outputs) {
231- STD_TORCH_CHECK (num_args == 5 , " num_args must be 5" );
232- STD_TORCH_CHECK (num_outputs == 2 , " num_outputs must be 2" );
233- std::tuple<Tensor, Tensor> res = compute (
234- /* logProbs*/ torch::stable::detail::to<Tensor>(stack[0 ]),
235- /* targets*/ torch::stable::detail::to<Tensor>(stack[1 ]),
236- /* logit_lengths*/ torch::stable::detail::to<Tensor>(stack[2 ]),
237- /* target_lengths*/ torch::stable::detail::to<Tensor>(stack[3 ]),
238- /* blank*/ float (torch::stable::detail::to<int64_t >(stack[4 ])));
239- stack[0 ] = torch::stable::detail::from (std::get<0 >(res));
240- stack[1 ] = torch::stable::detail::from (std::get<1 >(res));
241- }
242-
243227STABLE_TORCH_LIBRARY_IMPL (torchaudio, CPU, m) {
244- m.impl (" forced_align" , &boxed_forced_align_cpu );
228+ m.impl (" forced_align" , TORCH_BOX (&compute) );
245229}
246230
247231} // namespace cpu
0 commit comments