 
 #include <cub/cub.cuh>
 #include <limits.h>
-#include <iostream>
 
 namespace {
 constexpr int kNumThreads =
@@ -23,9 +22,9 @@ using torch::headeronly::ScalarType;
 
 template <typename scalar_t, typename target_t>
 __global__ void falign_cuda_step_kernel(
-    const PackedTensorAccessor32<scalar_t, 3>
+    const torchaudio::PackedTensorAccessor32<scalar_t, 3>
         logProbs_a,
-    const PackedTensorAccessor32<target_t, 2>
+    const torchaudio::PackedTensorAccessor32<target_t, 2>
         targets_a,
     const int T,
     const int L,
@@ -36,9 +35,9 @@ __global__ void falign_cuda_step_kernel(
     int start,
     int end,
     int backPtrBufferLen,
-    PackedTensorAccessor32<scalar_t, 2>
+    torchaudio::PackedTensorAccessor32<scalar_t, 2>
         alphas_a,
-    PackedTensorAccessor32<int8_t, 2>
+    torchaudio::PackedTensorAccessor32<int8_t, 2>
         backPtrBuffer_a) {
   scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
   const int batchIndex =
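
The recurring change in this commit qualifies the bare accessor helpers (`accessor`, `packed_accessor32`, `PackedTensorAccessor32`) with the `torchaudio::` namespace, presumably as part of the move onto the stable ABI. For reference, a minimal sketch of what a packed accessor is (a hypothetical stand-in, not torchaudio's actual type): the data pointer, sizes, and strides are held by value in one small struct, so it survives the by-value copy into a `__global__` kernel and can be indexed on the device.

#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for PackedTensorAccessor32. No host-side metadata
// pointers, so the copy made at kernel launch is safe to use on the device.
template <typename T, size_t N>
struct PackedAccessorSketch {
  T* data;
  int32_t sizes[N];
  int32_t strides[N];

  // 2-D element access: at(i, j) == data[i * strides[0] + j * strides[1]].
  __host__ __device__ T& at(int32_t i, int32_t j) const {
    return data[i * strides[0] + j * strides[1]];
  }
};
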
@@ -120,44 +119,38 @@ void forced_align_impl(
     const Tensor& targets,
     const int64_t blank,
     Tensor& paths) {
-  std::cout << "forced_align_impl: entering" << std::endl;
   auto defaultStream = at::cuda::getCurrentCUDAStream();
   auto cpuDataTranferStream = at::cuda::getStreamFromPool();
   const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
   using target_t = typename std::
       conditional<target_scalar_type == ScalarType::Int, int, int64_t>::type;
-  auto paths_a = accessor<target_t, 2>(paths);
+  auto paths_a = torchaudio::accessor<target_t, 2>(paths);
   const int batchIndex =
       0; // TODO: support batch version and use the real batch index
   const int T = logProbs.size(1); // num frames
   const int N = logProbs.size(2); // alphabet size
   const int L = targets.size(1); // label length
   const int S = 2 * L + 1;
 
-  std::cout << "forced_align_impl: 1" << std::endl;
   auto targetsCpu = torchaudio::stable::cpu(targets);
   // backPtrBuffer stores the index offset of the best path at the current
   // position. We copy the values to CPU after every kBackPtrBufferSize
   // frames.
-  std::cout << "forced_align_impl: 2" << std::endl;
   Tensor backPtrBuffer = torch::stable::new_empty(logProbs, {min(kBackPtrBufferSize, T), S}, ScalarType::Char);
   torch::stable::fill_(backPtrBuffer, -1);
 
-  std::cout << "forced_align_impl: 3" << std::endl;
   Tensor backPtrCpu = torch::stable::new_empty(targetsCpu, {T, S}, ScalarType::Char);
   torch::stable::fill_(backPtrCpu, -1);
 
   // we store only two time frames for alphas:
   // alphas for the current time frame can be computed from the previous
   // time frame alone.
-  std::cout << "forced_align_impl: 4" << std::endl;
   Tensor alphas = torch::stable::new_empty(logProbs, {2, S});
   torch::stable::fill_(alphas, kNegInfinity);
 
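
The {2, S} shape above is a rolling buffer: alphas at frame t depend only on frame t - 1, so two rows indexed by t % 2 replace a full {T, S} table. A CPU-only sketch of that recursion (synthetic scores, repeated labels ignored for brevity; not the kernel's actual code):

#include <algorithm>
#include <limits>
#include <vector>

int main() {
  const float kNegInf = -std::numeric_limits<float>::infinity();
  const int T = 3, S = 5;
  // logProbs[t][s]: scores already gathered per trellis state.
  std::vector<std::vector<float>> logProbs = {
      {-1.f, -2.f, -3.f, -4.f, -5.f},
      {-2.f, -1.f, -2.f, -3.f, -4.f},
      {-3.f, -2.f, -1.f, -2.f, -3.f}};
  std::vector<std::vector<float>> alphas(2, std::vector<float>(S, kNegInf));
  alphas[0][0] = logProbs[0][0]; // start in the leading blank ...
  alphas[0][1] = logProbs[0][1]; // ... or in the first label
  for (int t = 1; t < T; ++t) {
    auto& cur = alphas[t % 2];
    auto& prev = alphas[(t - 1) % 2];
    for (int s = 0; s < S; ++s) {
      float best = prev[s];                          // stay
      if (s > 0) best = std::max(best, prev[s - 1]); // advance
      // odd s is a label; skipping the preceding blank is legal only when
      // the neighbouring labels differ (assumed here for simplicity)
      if (s > 1 && s % 2 == 1) best = std::max(best, prev[s - 2]);
      cur[s] = best + logProbs[t][s];
    }
  }
  return 0;
}
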
   // CPU accessors
-  std::cout << "forced_align_impl: 5" << std::endl;
-  auto targetsCpu_a = accessor<target_t, 2>(targetsCpu);
-  auto backPtrCpu_a = accessor<int8_t, 2>(backPtrCpu);
+  auto targetsCpu_a = torchaudio::accessor<target_t, 2>(targetsCpu);
+  auto backPtrCpu_a = torchaudio::accessor<int8_t, 2>(backPtrCpu);
   // count the number of repeats in label
   int R = 0;
   for (int i = 1; i < L; ++i) {
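
The hunk ends inside the repeat-counting loop; in the full file its body increments R whenever two consecutive labels match. A worked example with a synthetic label, since R feeds the T - t <= L + R window check in the main loop below:

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> targets = {3, 3, 7}; // synthetic, not from the diff
  int L = static_cast<int>(targets.size());
  int S = 2 * L + 1; // trellis states: blank, 3, blank, 3, blank, 7, blank
  int R = 0;
  for (int i = 1; i < L; ++i) {
    if (targets[i] == targets[i - 1]) {
      ++R; // a repeat forces one extra mandatory blank frame
    }
  }
  std::printf("L=%d S=%d R=%d\n", L, S, R); // prints L=3 S=7 R=1
  return 0;
}
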
@@ -177,7 +170,6 @@ void forced_align_impl(
   int end = (S == 1) ? 1 : 2;
   int backPtrBufferLen = 0;
   Tensor bufferCopy;
-  std::cout << "forced_align_impl: 6" << std::endl;
   for (int t = 0; t < T; ++t) {
     if (t > 0) {
       if (T - t <= L + R) {
@@ -197,11 +189,10 @@ void forced_align_impl(
         end = end + 1;
       }
     }
-    std::cout << "forced_align_impl: t=" << t << std::endl;
     falign_cuda_step_kernel<scalar_t, target_t>
         <<<1, kNumThreads, 0, defaultStream>>>(
-            packed_accessor32<scalar_t, 3>(logProbs),
-            packed_accessor32<target_t, 2>(targets),
+            torchaudio::packed_accessor32<scalar_t, 3>(logProbs),
+            torchaudio::packed_accessor32<target_t, 2>(targets),
             T,
             L,
             N,
@@ -211,13 +202,12 @@ void forced_align_impl(
             start,
             end,
             backPtrBufferLen,
-            packed_accessor32<scalar_t, 2>(alphas),
-            packed_accessor32<int8_t, 2>(backPtrBuffer));
+            torchaudio::packed_accessor32<scalar_t, 2>(alphas),
+            torchaudio::packed_accessor32<int8_t, 2>(backPtrBuffer));
     C10_CUDA_KERNEL_LAUNCH_CHECK();
     ++backPtrBufferLen;
     if (backPtrBufferLen == kBackPtrBufferSize || t == T - 1) {
       cpuDataTranferStream.synchronize();
-
       // GPU -> GPU copy
       bufferCopy = torch::stable::clone(backPtrBuffer);
       STD_TORCH_CHECK(bufferCopy.is_contiguous(), "unexpected fail, need to implement stable::Tensor::contiguous()")
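
This block is why the function creates two streams: kernels keep producing backpointers (int8 suffices, since an offset is only ever 0, 1, or 2) on defaultStream, while a cloned snapshot of the previous chunk is drained to the CPU on cpuDataTranferStream. A sketch of that overlap pattern in the raw CUDA runtime API (an assumed analogue; the diff itself goes through the torch::stable wrappers):

#include <cstdint>
#include <cuda_runtime.h>

// Snapshot device-to-device on the compute stream, then let the copy stream
// drain the snapshot to (ideally pinned) host memory, so the next kernel
// launch is not blocked behind the device-to-host transfer.
void drainBackPtrChunk(const int8_t* d_buffer, int8_t* d_snapshot,
                       int8_t* h_dst, size_t bytes,
                       cudaStream_t compute, cudaStream_t copy) {
  cudaEvent_t snapshotDone;
  cudaEventCreate(&snapshotDone);
  // Ordered after the producing kernel: same stream.
  cudaMemcpyAsync(d_snapshot, d_buffer, bytes, cudaMemcpyDeviceToDevice,
                  compute);
  cudaEventRecord(snapshotDone, compute);
  // The copy stream waits for the snapshot only, not for later kernels.
  cudaStreamWaitEvent(copy, snapshotDone, 0);
  cudaMemcpyAsync(h_dst, d_snapshot, bytes, cudaMemcpyDeviceToHost, copy);
  cudaEventDestroy(snapshotDone);
  // The caller synchronizes `copy` before reading h_dst, mirroring the
  // cpuDataTranferStream.synchronize() calls above.
}
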
@@ -236,23 +226,20 @@ void forced_align_impl(
       backPtrBufferLen = 0;
     }
   }
-  std::cout << "forced_align_impl: 7" << std::endl;
   cpuDataTranferStream.synchronize();
   auto alphasCpu = torchaudio::stable::cpu(alphas);
-  auto alphasCpu_a = accessor<scalar_t, 2>(alphasCpu);
+  auto alphasCpu_a = torchaudio::accessor<scalar_t, 2>(alphasCpu);
   int curIdxOffset = ((T - 1) % 2);
   int ltrIdx =
       alphasCpu_a[curIdxOffset][S - 1] > alphasCpu_a[curIdxOffset][S - 2]
       ? S - 1
       : S - 2;
-  std::cout << "forced_align_impl: 8" << std::endl;
   for (int t = T - 1; t >= 0; --t) {
     auto lbl_idx =
         ltrIdx % 2 == 0 ? blank : targetsCpu_a[batchIndex][ltrIdx / 2];
     paths_a[batchIndex][t] = lbl_idx;
     ltrIdx -= backPtrCpu_a[t][ltrIdx];
   }
-  std::cout << "forced_align_impl: leaving" << std::endl;
 }
 
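The loop above is the Viterbi backtrace, done on the CPU: even trellis states decode to blank, odd state s decodes to targets[s / 2], and backPtrCpu_a[t][ltrIdx] holds how far the state index moved at frame t. A self-contained run-through with hand-picked offsets (synthetic data, not real kernel output):

#include <cstdint>
#include <cstdio>

int main() {
  // targets = {5, 8}: trellis states are (blank, 5, blank, 8, blank), S = 5.
  // Offsets below encode the alignment 5, blank, 8, 8 over T = 4 frames.
  const int T = 4;
  const int64_t blank = 0;
  const int targets[] = {5, 8};
  int8_t backPtr[4][5] = {};
  backPtr[3][3] = 0; // t=3: stayed on '8'
  backPtr[2][3] = 1; // t=2: advanced from the blank before '8'
  backPtr[1][2] = 1; // t=1: advanced from '5'
  backPtr[0][1] = 0; // t=0: start state

  int ltrIdx = 3; // pretend the alphas picked S - 2 as the best final state
  int64_t path[4];
  for (int t = T - 1; t >= 0; --t) {
    // Even states are blanks; odd state s corresponds to targets[s / 2].
    path[t] = (ltrIdx % 2 == 0) ? blank : targets[ltrIdx / 2];
    ltrIdx -= backPtr[t][ltrIdx];
  }
  for (int t = 0; t < T; ++t) std::printf("%lld ", (long long)path[t]);
  std::printf("\n"); // prints: 5 0 8 8
  return 0;
}
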
 std::tuple<Tensor, Tensor> compute(
@@ -261,7 +248,7 @@ std::tuple<Tensor, Tensor> compute(
     Tensor inputLengths,
     Tensor targetLengths,
     const int64_t blank) {
-  std::cout << "forced_align: compute " << std::endl;
+
   STD_TORCH_CHECK(logProbs.is_cuda(), "log_probs must be a CUDA tensor");
   STD_TORCH_CHECK(targets.is_cuda(), "targets must be a CUDA tensor");
   STD_TORCH_CHECK(
@@ -306,30 +293,21 @@ std::tuple<Tensor, Tensor> compute(
 
   auto B = logProbs.size(0);
   auto T = logProbs.size(1); // num frames
-  std::cout << "forced_align: compute: 1 " << std::endl;
+
   Tensor paths = torchaudio::stable::new_zeros(targets, {B, T}, /*dtype=*/std::nullopt, /*layout=*/std::nullopt, /*device=*/torchaudio::stable::cpu_device());
-  std::cout << "forced_align: compute: 2 " << std::endl;
+
   THO_DISPATCH_V2(logProbs.scalar_type(), "forced_align_impl", AT_WRAP([&] {
     if (targets.scalar_type() == ScalarType::Long) {
-      std::cout << "forced_align: compute: 2.1" << std::endl;
       (forced_align_impl<scalar_t, ScalarType::Long>(logProbs, targets, blank, paths));
-      std::cout << "forced_align: compute: 2.2" << std::endl;
     } else {
-      STD_TORCH_CHECK(targets.scalar_type() == ScalarType::Int, "unexpected dtype");
-      std::cout << "forced_align: compute: 2.3" << std::endl;
       (forced_align_impl<scalar_t, ScalarType::Int>(logProbs, targets, blank, paths));
-      std::cout << "forced_align: compute: 2.4" << std::endl;
-      }
+    }
   }), AT_EXPAND(AT_FLOATING_TYPES), ScalarType::Half);
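
THO_DISPATCH_V2 with AT_WRAP is the dtype-dispatch idiom: the runtime scalar type of logProbs picks a compile-time scalar_t for the wrapped lambda (the floating types plus Half here), and the inner branch handles the two supported target dtypes. A hand-rolled analogue of the mechanism (illustrative only, not the macro's literal expansion):

#include <cstdio>

enum class DType { Float, Double, Half };

template <typename scalar_t>
void forcedAlignBody() {
  std::printf("instantiated with sizeof(scalar_t) = %zu\n", sizeof(scalar_t));
}

// Map the runtime tag to a compile-time type, then instantiate the templated
// body for that type.
void dispatchByDtype(DType dtype) {
  switch (dtype) {
    case DType::Float:  forcedAlignBody<float>(); break;
    case DType::Double: forcedAlignBody<double>(); break;
    case DType::Half:   forcedAlignBody<unsigned short>(); break; // stand-in for a real half type
  }
}

int main() {
  dispatchByDtype(DType::Float); // prints: instantiated with sizeof(scalar_t) = 4
  return 0;
}
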
-  std::cout << "forced_align: compute: 3" << std::endl;
   Tensor pathsCuda = torchaudio::stable::cuda(paths, logProbs.get_device_index());
-  std::cout << "forced_align: compute: 4" << std::endl;
   return std::make_tuple(pathsCuda, logProbs);
 }
 
-
 STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
-  std::cout << "forced_align: library impl" << std::endl;
   m.impl("forced_align", TORCH_BOX(&compute));
 }
 