@@ -1,4 +1,6 @@
 #include <libtorchaudio/utils.h>
+#include <libtorchaudio/stable/TensorAccessor.h>
+#include <libtorchaudio/stable/dispatch.h>
 #include <torch/csrc/stable/library.h>
 
 #include <cub/cub.cuh>
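A note on the two added includes: the usage later in this diff implies that `stable/TensorAccessor.h` supplies torchaudio's own `PackedTensorAccessor32`/`RestrictPtrTraits` templates, and `stable/dispatch.h` the `STABLE_DISPATCH_FLOATING_TYPES_AND_HALF` macro. As a rough sketch only (the real headers may differ), the accessor side could look like this, assuming it mirrors ATen's header-only design but stores sizes and strides by value so the accessor can be passed to a kernel:

// Hypothetical sketch of libtorchaudio/stable/TensorAccessor.h -- an
// assumption modeled on at::PackedTensorAccessor32, not the real file.
#include <cstddef>
#include <cstdint>

namespace torchaudio::stable {

// Marks the data pointer __restrict__ so the compiler may assume no aliasing.
template <typename T>
struct RestrictPtrTraits {
  using PtrType = T* __restrict__;
};

// Device-copyable N-d view: raw pointer plus 32-bit sizes/strides by value.
template <typename T, size_t N,
          template <typename U> class PtrTraits = RestrictPtrTraits>
class PackedTensorAccessor32 {
 public:
  using PtrType = typename PtrTraits<T>::PtrType;

  PackedTensorAccessor32(PtrType data, const int32_t* sizes,
                         const int32_t* strides)
      : data_(data) {
    for (size_t i = 0; i < N; ++i) {
      sizes_[i] = sizes[i];
      strides_[i] = strides[i];
    }
  }
  // operator[] would peel off one dimension at a time, as in ATen.

 private:
  PtrType data_;
  int32_t sizes_[N];
  int32_t strides_[N];
};

} // namespace torchaudio::stable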
@@ -20,9 +22,9 @@ using torch::headeronly::ScalarType;
 
 template <typename scalar_t, typename target_t>
 __global__ void falign_cuda_step_kernel(
-    const at::PackedTensorAccessor32<scalar_t, 3, at::RestrictPtrTraits>
+    const torchaudio::stable::PackedTensorAccessor32<scalar_t, 3, torchaudio::stable::RestrictPtrTraits>
         logProbs_a,
-    const at::PackedTensorAccessor32<target_t, 2, at::RestrictPtrTraits>
+    const torchaudio::stable::PackedTensorAccessor32<target_t, 2, torchaudio::stable::RestrictPtrTraits>
         targets_a,
     const int T,
     const int L,
@@ -33,9 +35,9 @@ __global__ void falign_cuda_step_kernel(
     int start,
     int end,
     int backPtrBufferLen,
-    at::PackedTensorAccessor32<scalar_t, 2, at::RestrictPtrTraits>
+    torchaudio::stable::PackedTensorAccessor32<scalar_t, 2, torchaudio::stable::RestrictPtrTraits>
         alphas_a,
-    at::PackedTensorAccessor32<int8_t, 2, at::RestrictPtrTraits>
+    torchaudio::stable::PackedTensorAccessor32<int8_t, 2, torchaudio::stable::RestrictPtrTraits>
         backPtrBuffer_a) {
   scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
   const int batchIndex =
@@ -122,15 +124,15 @@ void forced_align_impl(
   const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
   using target_t = typename std::
       conditional<target_scalar_type == ScalarType::Int, int, int64_t>::type;
-  auto paths_a = paths.accessor<target_t, 2>();
+  auto paths_a = torchaudio::stable::accessor<target_t, 2>(paths);
   const int batchIndex =
       0; // TODO: support batch version and use the real batch index
   const int T = logProbs.size(1); // num frames
   const int N = logProbs.size(2); // alphabet size
   const int L = targets.size(1); // label length
   const int S = 2 * L + 1;
 
-  auto targetsCpu = torch::stable::cpu(targets);
+  auto targetsCpu = torchaudio::stable::cpu(targets);
   // backPtrBuffer stores the index offset of the best path at current position
   // We copy the values to CPU after running every kBackPtrBufferSize of
   // frames.
@@ -147,8 +149,8 @@ void forced_align_impl(
   torch::stable::fill_(alphas, kNegInfinity);
 
   // CPU accessors
-  auto targetsCpu_a = targetsCpu.accessor<target_t, 2>();
-  auto backPtrCpu_a = backPtrCpu.accessor<int8_t, 2>();
+  auto targetsCpu_a = torchaudio::stable::accessor<target_t, 2>(targetsCpu);
+  auto backPtrCpu_a = torchaudio::stable::accessor<int8_t, 2>(backPtrCpu);
   // count the number of repeats in label
   int R = 0;
   for (int i = 1; i < L; ++i) {
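Throughout the diff, member calls like `targetsCpu.accessor<target_t, 2>()` become free-function calls, since the stable `Tensor` does not expose ATen's `accessor()` member. A hedged sketch of such a helper, assuming the stable `TensorAccessor` copies sizes/strides by value (unlike `at::TensorAccessor`, which keeps pointers into the tensor's metadata; copying avoids dangling here):

// Hypothetical sketch of torchaudio::stable::accessor -- an assumption, not
// the actual implementation. T must match t's dtype; t must be an
// N-dimensional CPU tensor.
template <typename T, size_t N>
TensorAccessor<T, N> accessor(const torch::stable::Tensor& t) {
  STD_TORCH_CHECK(t.dim() == static_cast<int64_t>(N),
                  "accessor: expected an N-dimensional tensor");
  int64_t sizes[N];
  int64_t strides[N];
  for (size_t i = 0; i < N; ++i) {
    sizes[i] = t.size(i);
    strides[i] = t.stride(i);
  }
  // Assumes the stable TensorAccessor copies these arrays by value.
  return TensorAccessor<T, N>(
      static_cast<T*>(t.data_ptr()), sizes, strides);
}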
@@ -189,8 +191,8 @@ void forced_align_impl(
     }
     falign_cuda_step_kernel<scalar_t, target_t>
         <<<1, kNumThreads, 0, defaultStream>>>(
-            logProbs.packed_accessor32<scalar_t, 3, at::RestrictPtrTraits>(),
-            targets.packed_accessor32<target_t, 2, at::RestrictPtrTraits>(),
+            torchaudio::stable::packed_accessor32<scalar_t, 3, torchaudio::stable::RestrictPtrTraits>(logProbs),
+            torchaudio::stable::packed_accessor32<target_t, 2, torchaudio::stable::RestrictPtrTraits>(targets),
             T,
             L,
             N,
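The same pattern applies to the kernel arguments: the `packed_accessor32` member on `at::Tensor` is replaced by a free function. A sketch under the same assumptions (the 32-bit variant also guards against index overflow, as ATen's does):

// Hypothetical sketch of torchaudio::stable::packed_accessor32 -- an
// assumption modeled on at::Tensor::packed_accessor32.
template <typename T, size_t N, template <typename U> class PtrTraits>
PackedTensorAccessor32<T, N, PtrTraits> packed_accessor32(
    const torch::stable::Tensor& t) {
  STD_TORCH_CHECK(t.numel() <= INT32_MAX,
                  "packed_accessor32: tensor too large for 32-bit indexing");
  int32_t sizes[N];
  int32_t strides[N];
  for (size_t i = 0; i < N; ++i) {
    sizes[i] = static_cast<int32_t>(t.size(i));
    strides[i] = static_cast<int32_t>(t.stride(i));
  }
  // The sketch constructor above copies sizes/strides by value, so the
  // returned accessor can be passed to the kernel launch directly.
  return PackedTensorAccessor32<T, N, PtrTraits>(
      static_cast<T*>(t.data_ptr()), sizes, strides);
}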
@@ -200,15 +202,14 @@ void forced_align_impl(
             start,
             end,
             backPtrBufferLen,
-            alphas.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>(),
-            backPtrBuffer
-                .packed_accessor32<int8_t, 2, at::RestrictPtrTraits>());
+            torchaudio::stable::packed_accessor32<scalar_t, 2, torchaudio::stable::RestrictPtrTraits>(alphas),
+            torchaudio::stable::packed_accessor32<int8_t, 2, torchaudio::stable::RestrictPtrTraits>(backPtrBuffer));
     C10_CUDA_KERNEL_LAUNCH_CHECK();
     ++backPtrBufferLen;
     if (backPtrBufferLen == kBackPtrBufferSize || t == T - 1) {
       cpuDataTranferStream.synchronize();
       // GPU -> GPU copy
-      bufferCopy = backPtrBuffer.clone();
+      bufferCopy = torchaudio::stable::clone(backPtrBuffer);
       STD_TORCH_CHECK(bufferCopy.is_contiguous(), "unexpected fail, need to implement stable::Tensor::contiguous()");
       defaultStream.synchronize();
       at::cuda::setCurrentCUDAStream(cpuDataTranferStream);
@@ -227,8 +228,8 @@ void forced_align_impl(
   }
   cpuDataTranferStream.synchronize();
 
-  auto alphasCpu = torch::stable::cpu(alphas);
-  auto alphasCpu_a = alphasCpu.accessor<scalar_t, 2>();
+  auto alphasCpu = torchaudio::stable::cpu(alphas);
+  auto alphasCpu_a = torchaudio::stable::accessor<scalar_t, 2>(alphasCpu);
   int curIdxOffset = ((T - 1) % 2);
   int ltrIdx =
       alphasCpu_a[curIdxOffset][S - 1] > alphasCpu_a[curIdxOffset][S - 2]
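`torch::stable` is also short of `Tensor::cpu()`/`Tensor::cuda()`, hence the `torchaudio::stable::cpu`/`cuda` wrappers. The lines deleted from `compute()` in the final hunk below (a `new_empty` on the target device followed by `copy_`) suggest the wrappers amount to roughly:

// Hypothetical sketch of the device-transfer helpers, modeled on the
// new_empty + copy_ pattern the last hunk removes; not the actual code.
Tensor cpu(const Tensor& t) {
  Tensor out = torch::stable::new_empty(
      t, torchaudio::util::sizes(t), std::nullopt,
      aoti_torch_device_type_cpu());
  torch::stable::copy_(out, t);
  return out;
}

Tensor cuda(const Tensor& t, int32_t device_index) {
  Tensor out = torch::stable::new_empty(
      t, torchaudio::util::sizes(t), std::nullopt,
      aoti_torch_device_type_cuda(), device_index);
  torch::stable::copy_(out, t);
  return out;
}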
@@ -294,50 +295,20 @@ std::tuple<Tensor, Tensor> compute(
   auto B = logProbs.size(0);
   auto T = logProbs.size(1); // num frames
 
-  Tensor paths = torch::stable::new_empty(targets, {B, T}, std::nullopt, aoti_torch_device_type_cpu());
-  torch::stable::zero_(paths);
+  Tensor paths = torchaudio::stable::new_zeros(targets, {B, T}, /*dtype=*/std::nullopt, /*layout=*/std::nullopt, /*device=*/torchaudio::stable::cpu_device());
 
-  switch (logProbs.scalar_type()) {
-    case ScalarType::Double: {
-      if (targets.scalar_type() == ScalarType::Long) {
-        forced_align_impl<double, ScalarType::Long>(logProbs, targets, blank, paths);
-      } else if (targets.scalar_type() == ScalarType::Int) {
-        forced_align_impl<double, ScalarType::Int>(logProbs, targets, blank, paths);
-      } else {
-        STD_TORCH_CHECK(false, "unreachable");
-      }
-      break;
-    }
-    case ScalarType::Float: {
-      if (targets.scalar_type() == ScalarType::Long) {
-        forced_align_impl<float, ScalarType::Long>(logProbs, targets, blank, paths);
-      } else if (targets.scalar_type() == ScalarType::Int) {
-        forced_align_impl<float, ScalarType::Int>(logProbs, targets, blank, paths);
-      } else {
-        STD_TORCH_CHECK(false, "unreachable");
-      }
-      break;
-    }
-    case ScalarType::Half: {
-      if (targets.scalar_type() == ScalarType::Long) {
-        forced_align_impl<c10::Half, ScalarType::Long>(logProbs, targets, blank, paths);
-      } else if (targets.scalar_type() == ScalarType::Int) {
-        forced_align_impl<c10::Half, ScalarType::Int>(logProbs, targets, blank, paths);
-      } else {
-        STD_TORCH_CHECK(false, "unreachable");
-      }
-      break;
-    }
-    default: {
-      STD_TORCH_CHECK(false, "unreachable");
-    }
-  };
-  Tensor pathsCuda = torch::stable::new_empty(paths,
-      torchaudio::util::sizes(paths),
-      std::nullopt,
-      aoti_torch_device_type_cuda(),
-      logProbs.get_device_index());
-  torch::stable::copy_(pathsCuda, paths);
+  STABLE_DISPATCH_FLOATING_TYPES_AND_HALF(
+      logProbs.scalar_type(), "forced_align_impl", [&] {
+        if (targets.scalar_type() == ScalarType::Long) {
+          forced_align_impl<scalar_t, ScalarType::Long>(
+              logProbs, targets, blank, paths);
+        } else {
+          forced_align_impl<scalar_t, ScalarType::Int>(
+              logProbs, targets, blank, paths);
+        }
+      });
+
+  Tensor pathsCuda = torchaudio::stable::cuda(paths, logProbs.get_device_index());
   return std::make_tuple(pathsCuda, logProbs);
 }
 
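The removed switch makes the shape of `STABLE_DISPATCH_FLOATING_TYPES_AND_HALF` easy to guess. Assuming it mirrors ATen's `AT_DISPATCH_FLOATING_TYPES_AND_HALF`, a sketch consistent with the deleted code would be:

// Hypothetical sketch of the dispatch macro from stable/dispatch.h -- an
// assumption; it reproduces the Double/Float/Half switch the diff deletes.
#define STABLE_DISPATCH_CASE(enum_type, cpp_type, ...) \
  case enum_type: {                                    \
    using scalar_t = cpp_type;                         \
    __VA_ARGS__();                                     \
    break;                                             \
  }

#define STABLE_DISPATCH_FLOATING_TYPES_AND_HALF(SCALAR_TYPE, NAME, ...) \
  switch (SCALAR_TYPE) {                                                \
    STABLE_DISPATCH_CASE(ScalarType::Double, double, __VA_ARGS__)       \
    STABLE_DISPATCH_CASE(ScalarType::Float, float, __VA_ARGS__)         \
    STABLE_DISPATCH_CASE(ScalarType::Half, c10::Half, __VA_ARGS__)      \
    default:                                                            \
      STD_TORCH_CHECK(false, NAME, ": unsupported scalar type");        \
  }

Passing the lambda through `__VA_ARGS__` lets template arguments containing commas (such as `forced_align_impl<scalar_t, ScalarType::Long>`) survive the preprocessor intact, which is the same trick ATen's dispatch macros use.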