Skip to content

Commit 7d09bdd

Browse files
committed
A possible fix
1 parent bf219a9 commit 7d09bdd

File tree

1 file changed

+3
-10
lines changed

1 file changed

+3
-10
lines changed

src/libtorchaudio/forced_align/gpu/compute.cu

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -255,13 +255,6 @@ void forced_align_impl(
255255
std::cout << "forced_align_impl: leaving" << std::endl;
256256
}
257257

258-
template <typename scalar_t>
259-
const auto forced_align_long_impl =
260-
forced_align_impl<scalar_t, ScalarType::Long>;
261-
262-
template <typename scalar_t>
263-
const auto forced_align_int_impl = forced_align_impl<scalar_t, ScalarType::Int>;
264-
265258
std::tuple<Tensor, Tensor> compute(
266259
Tensor logProbs,
267260
Tensor targets,
@@ -319,12 +312,12 @@ std::tuple<Tensor, Tensor> compute(
319312
THO_DISPATCH_V2(logProbs.scalar_type(), "forced_align_impl", AT_WRAP([&] {
320313
if (targets.scalar_type() == ScalarType::Long) {
321314
std::cout << "forced_align: compute: 2.1" << std::endl;
322-
forced_align_long_impl<scalar_t>(logProbs, targets, blank, paths);
315+
(forced_align_impl<scalar_t, ScalarType::Long>(logProbs, targets, blank, paths));
323316
std::cout << "forced_align: compute: 2.2" << std::endl;
324317
} else {
325-
std::cout << "forced_align: compute: 2.3" << std::endl;
326318
STD_TORCH_CHECK(targets.scalar_type() == ScalarType::Int, "unexpected dtype");
327-
forced_align_int_impl<scalar_t>(logProbs, targets, blank, paths);
319+
std::cout << "forced_align: compute: 2.3" << std::endl;
320+
(forced_align_impl<scalar_t, ScalarType::Int>(logProbs, targets, blank, paths));
328321
std::cout << "forced_align: compute: 2.4" << std::endl;
329322
}
330323
}), AT_EXPAND(AT_FLOATING_TYPES), ScalarType::Half);

0 commit comments

Comments
 (0)