 #include <libtorchaudio/utils.h>
-#include <libtorchaudio/stable/TensorAccessor.h>
 #include <torch/csrc/stable/library.h>
 #include <torch/headeronly/core/Dispatch_v2.h>
 #include <torch/headeronly/core/ScalarType.h>
@@ -23,9 +22,9 @@ using torch::headeronly::ScalarType;

 template <typename scalar_t, typename target_t>
 __global__ void falign_cuda_step_kernel(
-    const torchaudio::stable::PackedTensorAccessor32<scalar_t, 3, torchaudio::stable::RestrictPtrTraits>
+    const PackedTensorAccessor32<scalar_t, 3>
         logProbs_a,
-    const torchaudio::stable::PackedTensorAccessor32<target_t, 2, torchaudio::stable::RestrictPtrTraits>
+    const PackedTensorAccessor32<target_t, 2>
         targets_a,
     const int T,
     const int L,
@@ -36,9 +35,9 @@ __global__ void falign_cuda_step_kernel(
     int start,
     int end,
     int backPtrBufferLen,
-    torchaudio::stable::PackedTensorAccessor32<scalar_t, 2, torchaudio::stable::RestrictPtrTraits>
+    PackedTensorAccessor32<scalar_t, 2>
         alphas_a,
-    torchaudio::stable::PackedTensorAccessor32<int8_t, 2, torchaudio::stable::RestrictPtrTraits>
+    PackedTensorAccessor32<int8_t, 2>
         backPtrBuffer_a) {
   scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
   const int batchIndex =
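The unqualified PackedTensorAccessor32<scalar_t, 3> spellings introduced above only compile if the restrict pointer traits are now baked in somewhere upstream (note the dropped libtorchaudio/stable/TensorAccessor.h include). A minimal sketch of the kind of alias this change presumes; the name, location, and exact form are assumptions, not shown in this diff:

    // Hypothetical alias defaulting to RestrictPtrTraits, so kernels can write
    // PackedTensorAccessor32<T, N> instead of the fully qualified spelling.
    template <typename T, size_t N>
    using PackedTensorAccessor32 = torchaudio::stable::
        PackedTensorAccessor32<T, N, torchaudio::stable::RestrictPtrTraits>;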
@@ -125,7 +124,7 @@ void forced_align_impl(
   const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
   using target_t = typename std::
       conditional<target_scalar_type == ScalarType::Int, int, int64_t>::type;
-  auto paths_a = torchaudio::stable::accessor<target_t, 2>(paths);
+  auto paths_a = accessor<target_t, 2>(paths);
   const int batchIndex =
       0; // TODO: support batch version and use the real batch index
   const int T = logProbs.size(1); // num frames
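Likewise, the bare accessor<target_t, 2>(paths) call here, and the packed_accessor32<...> calls at the launch site below, presumably resolve through using-declarations pulled in by libtorchaudio/utils.h, along these lines (an assumption, not shown in this diff):

    // Hypothetical using-declarations that would make the unqualified helper
    // calls resolve; the packed variant would also need its traits parameter
    // defaulted to RestrictPtrTraits.
    using torchaudio::stable::accessor;
    using torchaudio::stable::packed_accessor32;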
@@ -192,8 +191,8 @@ void forced_align_impl(
     }
     falign_cuda_step_kernel<scalar_t, target_t>
         <<<1, kNumThreads, 0, defaultStream>>>(
-            torchaudio::stable::packed_accessor32<scalar_t, 3, torchaudio::stable::RestrictPtrTraits>(logProbs),
-            torchaudio::stable::packed_accessor32<target_t, 2, torchaudio::stable::RestrictPtrTraits>(targets),
+            packed_accessor32<scalar_t, 3>(logProbs),
+            packed_accessor32<target_t, 2>(targets),
             T,
             L,
             N,
@@ -203,12 +202,13 @@ void forced_align_impl(
             start,
             end,
             backPtrBufferLen,
-            torchaudio::stable::packed_accessor32<scalar_t, 2, torchaudio::stable::RestrictPtrTraits>(alphas),
-            torchaudio::stable::packed_accessor32<int8_t, 2, torchaudio::stable::RestrictPtrTraits>(backPtrBuffer));
+            packed_accessor32<scalar_t, 2>(alphas),
+            packed_accessor32<int8_t, 2>(backPtrBuffer));
     C10_CUDA_KERNEL_LAUNCH_CHECK();
     ++backPtrBufferLen;
     if (backPtrBufferLen == kBackPtrBufferSize || t == T - 1) {
       cpuDataTranferStream.synchronize();
+
       // GPU -> GPU copy
       bufferCopy = torch::stable::clone(backPtrBuffer);
       STD_TORCH_CHECK(bufferCopy.is_contiguous(), "unexpected fail, need to implement stable::Tensor::contiguous()")
@@ -228,7 +228,6 @@ void forced_align_impl(
     }
   }
   cpuDataTranferStream.synchronize();
-
   auto alphasCpu = torchaudio::stable::cpu(alphas);
   auto alphasCpu_a = torchaudio::stable::accessor<scalar_t, 2>(alphasCpu);
   int curIdxOffset = ((T - 1) % 2);
@@ -252,10 +251,10 @@ template <typename scalar_t>
 const auto forced_align_int_impl = forced_align_impl<scalar_t, ScalarType::Int>;

 std::tuple<Tensor, Tensor> compute(
-    const Tensor& logProbs,
-    const Tensor& targets,
-    const Tensor& inputLengths,
-    const Tensor& targetLengths,
+    Tensor logProbs,
+    Tensor targets,
+    Tensor inputLengths,
+    Tensor targetLengths,
     const int64_t blank) {

   STD_TORCH_CHECK(logProbs.is_cuda(), "log_probs must be a CUDA tensor");
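Dropping the const Tensor& qualifiers in favor of by-value Tensor parameters lines up with the TORCH_BOX registration below: the generated boxed wrapper materializes owned Tensor values from the StableIValue stack before calling compute, and torch::stable::Tensor is a reference-counted handle that is cheap to copy. That rationale is inferred from the registration change rather than stated in the diff; see the sketch after the registration block.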
@@ -317,21 +316,9 @@ std::tuple<Tensor, Tensor> compute(
   return std::make_tuple(pathsCuda, logProbs);
 }

-void boxed_forced_align_gpu(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
-  STD_TORCH_CHECK(num_args == 5, "num_args must be 5");
-  STD_TORCH_CHECK(num_outputs == 2, "num_outputs must be 2");
-  std::tuple<Tensor, Tensor> res = compute(
-      /*logProbs*/ torch::stable::detail::to<Tensor>(stack[0]),
-      /*targets*/ torch::stable::detail::to<Tensor>(stack[1]),
-      /*logit_lengths*/ torch::stable::detail::to<Tensor>(stack[2]),
-      /*target_lengths*/ torch::stable::detail::to<Tensor>(stack[3]),
-      /*blank*/ float(torch::stable::detail::to<int64_t>(stack[4])));
-  stack[0] = torch::stable::detail::from(std::get<0>(res));
-  stack[1] = torch::stable::detail::from(std::get<1>(res));
-}

 STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
-  m.impl("forced_align", &boxed_forced_align_gpu);
+  m.impl("forced_align", TORCH_BOX(&compute));
 }
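For reference, TORCH_BOX(&compute) generates a boxed wrapper playing the same role as the hand-written boxed_forced_align_gpu deleted above. A simplified sketch of what such a wrapper does, modeled on the deleted code rather than on the macro's actual expansion:

    // Rough equivalent of what TORCH_BOX(&compute) provides: convert the
    // declared arguments from the StableIValue stack, call the unboxed
    // function, and write the results back onto the stack.
    void boxed_compute(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
      STD_TORCH_CHECK(num_args == 5, "num_args must be 5");
      STD_TORCH_CHECK(num_outputs == 2, "num_outputs must be 2");
      std::tuple<Tensor, Tensor> res = compute(
          torch::stable::detail::to<Tensor>(stack[0]),   // logProbs
          torch::stable::detail::to<Tensor>(stack[1]),   // targets
          torch::stable::detail::to<Tensor>(stack[2]),   // inputLengths
          torch::stable::detail::to<Tensor>(stack[3]),   // targetLengths
          torch::stable::detail::to<int64_t>(stack[4])); // blank
      stack[0] = torch::stable::detail::from(std::get<0>(res));
      stack[1] = torch::stable::detail::from(std::get<1>(res));
    }

The macro removes this boilerplate and keeps the argument conversions in sync with compute's signature; it also drops the stray float(...) cast the hand-written version applied to the int64_t blank argument.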

 } // namespace gpu