11#include < libtorchaudio/utils.h>
2- #include < libtorchaudio/stable/TensorAccessor.h>
32#include < torch/csrc/stable/library.h>
43#include < torch/headeronly/core/Dispatch_v2.h>
54#include < torch/headeronly/core/ScalarType.h>
@@ -23,9 +22,9 @@ using torch::headeronly::ScalarType;
2322
2423template <typename scalar_t , typename target_t >
2524__global__ void falign_cuda_step_kernel (
26- const torchaudio::stable:: PackedTensorAccessor32<scalar_t , 3 , torchaudio::stable::RestrictPtrTraits >
25+ const torchaudio::PackedTensorAccessor32<scalar_t , 3 >
2726 logProbs_a,
28- const torchaudio::stable:: PackedTensorAccessor32<target_t , 2 , torchaudio::stable::RestrictPtrTraits >
27+ const torchaudio::PackedTensorAccessor32<target_t , 2 >
2928 targets_a,
3029 const int T,
3130 const int L,
@@ -36,9 +35,9 @@ __global__ void falign_cuda_step_kernel(
3635 int start,
3736 int end,
3837 int backPtrBufferLen,
39- torchaudio::stable:: PackedTensorAccessor32<scalar_t , 2 , torchaudio::stable::RestrictPtrTraits >
38+ torchaudio::PackedTensorAccessor32<scalar_t , 2 >
4039 alphas_a,
41- torchaudio::stable:: PackedTensorAccessor32<int8_t , 2 , torchaudio::stable::RestrictPtrTraits >
40+ torchaudio::PackedTensorAccessor32<int8_t , 2 >
4241 backPtrBuffer_a) {
4342 scalar_t kNegInfinity = -std::numeric_limits<scalar_t >::infinity ();
4443 const int batchIndex =
@@ -125,7 +124,7 @@ void forced_align_impl(
125124 const scalar_t kNegInfinity = -std::numeric_limits<scalar_t >::infinity ();
126125 using target_t = typename std::
127126 conditional<target_scalar_type == ScalarType::Int, int , int64_t >::type;
128- auto paths_a = torchaudio::stable:: accessor<target_t , 2 >(paths);
127+ auto paths_a = torchaudio::accessor<target_t , 2 >(paths);
129128 const int batchIndex =
130129 0 ; // TODO: support batch version and use the real batch index
131130 const int T = logProbs.size (1 ); // num frames
@@ -150,8 +149,8 @@ void forced_align_impl(
150149 torch::stable::fill_ (alphas, kNegInfinity );
151150
152151 // CPU accessors
153- auto targetsCpu_a = torchaudio::stable:: accessor<target_t , 2 >(targetsCpu);
154- auto backPtrCpu_a = torchaudio::stable:: accessor<int8_t , 2 >(backPtrCpu);
152+ auto targetsCpu_a = torchaudio::accessor<target_t , 2 >(targetsCpu);
153+ auto backPtrCpu_a = torchaudio::accessor<int8_t , 2 >(backPtrCpu);
155154 // count the number of repeats in label
156155 int R = 0 ;
157156 for (int i = 1 ; i < L; ++i) {
@@ -192,8 +191,8 @@ void forced_align_impl(
192191 }
193192 falign_cuda_step_kernel<scalar_t , target_t >
194193 <<<1 , kNumThreads , 0 , defaultStream>>> (
195- torchaudio::stable:: packed_accessor32<scalar_t , 3 , torchaudio::stable::RestrictPtrTraits >(logProbs),
196- torchaudio::stable:: packed_accessor32<target_t , 2 , torchaudio::stable::RestrictPtrTraits >(targets),
194+ torchaudio::packed_accessor32<scalar_t , 3 >(logProbs),
195+ torchaudio::packed_accessor32<target_t , 2 >(targets),
197196 T,
198197 L,
199198 N,
@@ -203,8 +202,8 @@ void forced_align_impl(
203202 start,
204203 end,
205204 backPtrBufferLen,
206- torchaudio::stable:: packed_accessor32<scalar_t , 2 , torchaudio::stable::RestrictPtrTraits >(alphas),
207- torchaudio::stable:: packed_accessor32<int8_t , 2 , torchaudio::stable::RestrictPtrTraits >(backPtrBuffer));
205+ torchaudio::packed_accessor32<scalar_t , 2 >(alphas),
206+ torchaudio::packed_accessor32<int8_t , 2 >(backPtrBuffer));
208207 C10_CUDA_KERNEL_LAUNCH_CHECK ();
209208 ++backPtrBufferLen;
210209 if (backPtrBufferLen == kBackPtrBufferSize || t == T - 1 ) {
@@ -228,9 +227,8 @@ void forced_align_impl(
228227 }
229228 }
230229 cpuDataTranferStream.synchronize ();
231-
232230 auto alphasCpu = torchaudio::stable::cpu (alphas);
233- auto alphasCpu_a = torchaudio::stable:: accessor<scalar_t , 2 >(alphasCpu);
231+ auto alphasCpu_a = torchaudio::accessor<scalar_t , 2 >(alphasCpu);
234232 int curIdxOffset = ((T - 1 ) % 2 );
235233 int ltrIdx =
236234 alphasCpu_a[curIdxOffset][S - 1 ] > alphasCpu_a[curIdxOffset][S - 2 ]
@@ -244,18 +242,11 @@ void forced_align_impl(
244242 }
245243}
246244
247- template <typename scalar_t >
248- const auto forced_align_long_impl =
249- forced_align_impl<scalar_t , ScalarType::Long>;
250-
251- template <typename scalar_t >
252- const auto forced_align_int_impl = forced_align_impl<scalar_t , ScalarType::Int>;
253-
254245std::tuple<Tensor, Tensor> compute (
255- const Tensor& logProbs,
256- const Tensor& targets,
257- const Tensor& inputLengths,
258- const Tensor& targetLengths,
246+ Tensor logProbs,
247+ Tensor targets,
248+ Tensor inputLengths,
249+ Tensor targetLengths,
259250 const int64_t blank) {
260251
261252 STD_TORCH_CHECK (logProbs.is_cuda (), " log_probs must be a CUDA tensor" );
@@ -307,31 +298,17 @@ std::tuple<Tensor, Tensor> compute(
307298
308299 THO_DISPATCH_V2 (logProbs.scalar_type (), " forced_align_impl" , AT_WRAP ([&] {
309300 if (targets.scalar_type () == ScalarType::Long) {
310- forced_align_long_impl <scalar_t >(logProbs, targets, blank, paths);
301+ (forced_align_impl <scalar_t , ScalarType::Long >(logProbs, targets, blank, paths) );
311302 } else {
312- forced_align_int_impl <scalar_t >(logProbs, targets, blank, paths);
313- }
303+ (forced_align_impl <scalar_t , ScalarType::Int >(logProbs, targets, blank, paths) );
304+ }
314305 }), AT_EXPAND (AT_FLOATING_TYPES), ScalarType::Half);
315-
316306 Tensor pathsCuda = torchaudio::stable::cuda (paths, logProbs.get_device_index ());
317307 return std::make_tuple (pathsCuda, logProbs);
318308}
319309
320- void boxed_forced_align_gpu (StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
321- STD_TORCH_CHECK (num_args == 5 , " num_args must be 5" );
322- STD_TORCH_CHECK (num_outputs == 2 , " num_outputs must be 2" );
323- std::tuple<Tensor, Tensor> res = compute (
324- /* logProbs*/ torch::stable::detail::to<Tensor>(stack[0 ]),
325- /* targets*/ torch::stable::detail::to<Tensor>(stack[1 ]),
326- /* logit_lengths*/ torch::stable::detail::to<Tensor>(stack[2 ]),
327- /* target_lengths*/ torch::stable::detail::to<Tensor>(stack[3 ]),
328- /* blank*/ float (torch::stable::detail::to<int64_t >(stack[4 ])));
329- stack[0 ] = torch::stable::detail::from (std::get<0 >(res));
330- stack[1 ] = torch::stable::detail::from (std::get<1 >(res));
331- }
332-
333310STABLE_TORCH_LIBRARY_IMPL (torchaudio, CUDA, m) {
334- m.impl (" forced_align" , &boxed_forced_align_gpu );
311+ m.impl (" forced_align" , TORCH_BOX (&compute) );
335312}
336313
337314} // namespace gpu
0 commit comments