Skip to content

Commit a232442

Browse files
committed
Use TORCH_BOX in forced_align/cpu.
1 parent 2b6e8ee commit a232442

File tree

3 files changed

+17
-23
lines changed

3 files changed

+17
-23
lines changed

src/libtorchaudio/forced_align/compute.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,9 @@
22

33
STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
44
m.def(
5-
"forced_align(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank) -> (Tensor, Tensor)");
5+
"forced_align(Tensor log_probs,"
6+
"Tensor targets,"
7+
"Tensor input_lengths,"
8+
"Tensor target_lengths,"
9+
"int blank) -> (Tensor, Tensor)");
610
}

src/libtorchaudio/forced_align/cpu/compute.cpp

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -147,10 +147,10 @@ template <typename scalar_t>
147147
const auto forced_align_int_impl = forced_align_impl<scalar_t, ScalarType::Int>;
148148

149149
std::tuple<Tensor, Tensor> compute(
150-
const Tensor& logProbs,
151-
const Tensor& targets,
152-
const Tensor& inputLengths,
153-
const Tensor& targetLengths,
150+
Tensor logProbs,
151+
Tensor targets,
152+
Tensor inputLengths,
153+
Tensor targetLengths,
154154
const int64_t blank) {
155155
STD_TORCH_CHECK(logProbs.is_cpu(), "log_probs must be a CPU tensor");
156156
STD_TORCH_CHECK(targets.is_cpu(), "targets must be a CPU tensor");
@@ -224,24 +224,8 @@ std::tuple<Tensor, Tensor> compute(
224224
return std::make_tuple(paths, logProbs);
225225
}
226226

227-
void boxed_forced_align_cpu(
228-
StableIValue* stack,
229-
uint64_t num_args,
230-
uint64_t num_outputs) {
231-
STD_TORCH_CHECK(num_args == 5, "num_args must be 5");
232-
STD_TORCH_CHECK(num_outputs == 2, "num_outputs must be 2");
233-
std::tuple<Tensor, Tensor> res = compute(
234-
/*logProbs*/ torch::stable::detail::to<Tensor>(stack[0]),
235-
/*targets*/ torch::stable::detail::to<Tensor>(stack[1]),
236-
/*logit_lengths*/ torch::stable::detail::to<Tensor>(stack[2]),
237-
/*target_lengths*/ torch::stable::detail::to<Tensor>(stack[3]),
238-
/*blank*/ float(torch::stable::detail::to<int64_t>(stack[4])));
239-
stack[0] = torch::stable::detail::from(std::get<0>(res));
240-
stack[1] = torch::stable::detail::from(std::get<1>(res));
241-
}
242-
243227
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
244-
m.impl("forced_align", &boxed_forced_align_cpu);
228+
m.impl("forced_align", TORCH_BOX(&compute));
245229
}
246230

247231
} // namespace cpu

src/libtorchaudio/forced_align/gpu/compute.cu

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,13 +318,19 @@ std::tuple<Tensor, Tensor> compute(
318318
std::cout << "forced_align: compute: 2" << std::endl;
319319
THO_DISPATCH_V2(logProbs.scalar_type(), "forced_align_impl", AT_WRAP([&] {
320320
if (targets.scalar_type() == ScalarType::Long) {
321+
std::cout << "forced_align: compute: 2.1" << std::endl;
321322
forced_align_long_impl<scalar_t>(logProbs, targets, blank, paths);
323+
std::cout << "forced_align: compute: 2.2" << std::endl;
322324
} else {
325+
std::cout << "forced_align: compute: 2.3" << std::endl;
326+
STD_TORCH_CHECK(targets.scalar_type() == ScalarType::Int, "unexpected dtype");
323327
forced_align_int_impl<scalar_t>(logProbs, targets, blank, paths);
328+
std::cout << "forced_align: compute: 2.4" << std::endl;
324329
}
325330
}), AT_EXPAND(AT_FLOATING_TYPES), ScalarType::Half);
326-
331+
std::cout << "forced_align: compute: 3" << std::endl;
327332
Tensor pathsCuda = torchaudio::stable::cuda(paths, logProbs.get_device_index());
333+
std::cout << "forced_align: compute: 4" << std::endl;
328334
return std::make_tuple(pathsCuda, logProbs);
329335
}
330336

0 commit comments

Comments
 (0)