@@ -255,13 +255,6 @@ void forced_align_impl(
255255 std::cout << " forced_align_impl: leaving" << std::endl;
256256}
257257
258- template <typename scalar_t >
259- const auto forced_align_long_impl =
260- forced_align_impl<scalar_t , ScalarType::Long>;
261-
262- template <typename scalar_t >
263- const auto forced_align_int_impl = forced_align_impl<scalar_t , ScalarType::Int>;
264-
265258std::tuple<Tensor, Tensor> compute (
266259 Tensor logProbs,
267260 Tensor targets,
@@ -319,12 +312,12 @@ std::tuple<Tensor, Tensor> compute(
319312 THO_DISPATCH_V2 (logProbs.scalar_type (), " forced_align_impl" , AT_WRAP ([&] {
320313 if (targets.scalar_type () == ScalarType::Long) {
321314 std::cout << " forced_align: compute: 2.1" << std::endl;
322- forced_align_long_impl <scalar_t >(logProbs, targets, blank, paths);
315+ (forced_align_impl <scalar_t , ScalarType::Long >(logProbs, targets, blank, paths) );
323316 std::cout << " forced_align: compute: 2.2" << std::endl;
324317 } else {
325- std::cout << " forced_align: compute: 2.3" << std::endl;
326318 STD_TORCH_CHECK (targets.scalar_type () == ScalarType::Int, " unexpected dtype" );
327- forced_align_int_impl<scalar_t >(logProbs, targets, blank, paths);
319+ std::cout << " forced_align: compute: 2.3" << std::endl;
320+ (forced_align_impl<scalar_t , ScalarType::Int>(logProbs, targets, blank, paths));
328321 std::cout << " forced_align: compute: 2.4" << std::endl;
329322 }
330323 }), AT_EXPAND (AT_FLOATING_TYPES), ScalarType::Half);
0 commit comments