@@ -14,19 +14,17 @@ namespace torchaudio {
14
14
namespace alignment {
15
15
namespace cpu {
16
16
17
-
17
+ using torch::stable::Tensor;
18
18
19
19
// Inspired from
20
20
// https://github.com/flashlight/sequence/blob/main/flashlight/lib/sequence/criterion/cpu/ConnectionistTemporalClassificationCriterion.cpp
21
- template <typename scalar_t , at::ScalarType target_scalar_type >
21
+ template <typename scalar_t , typename target_t >
22
22
void forced_align_impl (
23
- const torch:: Tensor& logProbs,
24
- const torch:: Tensor& targets,
25
- const int64_t blank,
26
- torch:: Tensor& paths) {
23
+ const Tensor logProbs,
24
+ const Tensor targets,
25
+ const Tensor blank,
26
+ Tensor paths) {
27
27
const scalar_t kNegInfinity = -std::numeric_limits<scalar_t >::infinity ();
28
- using target_t = typename std::
29
- conditional<target_scalar_type == torch::kInt , int , int64_t >::type;
30
28
const auto batchIndex =
31
29
0 ; // TODO: support batch version and use the real batch index
32
30
const auto T = logProbs.size (1 );
@@ -136,11 +134,11 @@ void forced_align_impl(
136
134
}
137
135
}
138
136
139
- std::tuple<torch:: Tensor, torch:: Tensor> compute (
140
- const torch:: Tensor& logProbs,
141
- const torch:: Tensor& targets,
142
- const torch:: Tensor& inputLengths,
143
- const torch:: Tensor& targetLengths,
137
+ std::tuple<Tensor, Tensor> compute (
138
+ const Tensor& logProbs,
139
+ const Tensor& targets,
140
+ const Tensor& inputLengths,
141
+ const Tensor& targetLengths,
144
142
const int64_t blank) {
145
143
TORCH_CHECK (logProbs.is_cpu (), " log_probs must be a CPU tensor" );
146
144
TORCH_CHECK (targets.is_cpu (), " targets must be a CPU tensor" );
@@ -185,19 +183,31 @@ std::tuple<torch::Tensor, torch::Tensor> compute(
185
183
186
184
const auto B = logProbs.size (0 );
187
185
const auto T = logProbs.size (1 );
188
- auto paths = torch::zeros (
189
- {B, T},
190
- torch::TensorOptions ().device (targets.device ()).dtype (targets.dtype ()));
191
- AT_DISPATCH_FLOATING_TYPES_AND_HALF (
192
- logProbs.scalar_type (), " forced_align_impl" , [&] {
193
- if (targets.scalar_type () == torch::kInt64 ) {
194
- forced_align_impl<scalar_t , torch::kInt64 >(
195
- logProbs, targets, blank, paths);
196
- } else {
197
- forced_align_impl<scalar_t , torch::kInt32 >(
198
- logProbs, targets, blank, paths);
199
- }
200
- });
186
+
187
+ int64_t paths_size[2 ] = {B, T};
188
+ int64_t paths_stride[2 ] = {T, 1 };
189
+ AtenTensorHandle paths_h;
190
+ aoti_torch_empty_strided (1 , paths_size, paths_stride, targets_dtype, targets_device, targets_device_index, &paths_h);
191
+ auto paths = Tensor (paths_h);
192
+
193
+
194
+ if (targets.scalar_type () == aoti_torch_dtype_int64 ()) {
195
+ if (logProbs.scalar_type () == aoti_torch_dtype_float64 ()) {
196
+ forced_align_impl<float64, int64>(logProbs, targets, blank, paths);
197
+ } else if (logProbs.scalar_type () == aoti_torch_dtype_float32 ()) {
198
+ forced_align_impl<float32, int64>(logProbs, targets, blank, paths);
199
+ } else if (logProbs.scalar_type () == aoti_torch_dtype_float16 ()) {
200
+ forced_align_impl<float16, int64>(logProbs, targets, blank, paths);
201
+ }
202
+ } else if (targets.scalar_type () == aoti_torch_dtype_int32 ()) {
203
+ if (logProbs.scalar_type () == aoti_torch_dtype_float64 ()) {
204
+ forced_align_impl<float64, int32>(logProbs, targets, blank, paths);
205
+ } else if (logProbs.scalar_type () == aoti_torch_dtype_float32 ()) {
206
+ forced_align_impl<float32, int32>(logProbs, targets, blank, paths);
207
+ } else if (logProbs.scalar_type () == aoti_torch_dtype_float16 ()) {
208
+ forced_align_impl<float16, int32>(logProbs, targets, blank, paths);
209
+ }
210
+ }
201
211
return std::make_tuple (
202
212
paths,
203
213
logProbs.index (
@@ -207,8 +217,21 @@ std::tuple<torch::Tensor, torch::Tensor> compute(
207
217
paths.index ({0 })}));
208
218
}
209
219
210
- TORCH_LIBRARY_IMPL (torchaudio, CPU, m) {
211
- m.impl (" forced_align" , &compute);
220
+
221
// Stack-based wrapper around compute(): unpacks five arguments from `stack`,
// calls compute(), and writes its two results back into the same stack.
//   stack[0..3]  each read as an AtenTensorHandle and wrapped in a Tensor;
//                passed to compute() positionally as logProbs, targets,
//                inputLengths, targetLengths.
//   stack[4]     read as int64_t and passed to compute() as `blank`.
// On return, stack[0] holds the first tuple element from compute() and
// stack[1] holds the second.
// num_args and num_outputs are not read in this body.
// NOTE(review): presumably this signature is the boxed-kernel calling
// convention expected by STABLE_TORCH_LIBRARY_IMPL — confirm against the
// stable library header.
+ void boxed_compute (StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
222
+ Tensor t1 (to<AtenTensorHandle>(stack[0 ]));  // -> compute() logProbs
223
+ Tensor t2 (to<AtenTensorHandle>(stack[1 ]));  // -> compute() targets
224
+ Tensor t3 (to<AtenTensorHandle>(stack[2 ]));  // -> compute() inputLengths
225
+ Tensor t4 (to<AtenTensorHandle>(stack[3 ]));  // -> compute() targetLengths
226
+ int64_t blank = to<int64_t >(stack[4 ]);
227
+ auto result = compute (
228
+ std::move (t1), std::move (t2), std::move (t3), std::move (t4), blank);  // compute() takes const Tensor&, so the moves only bind references
229
+ stack[0 ] = from (std::get<0 >(result));  // overwrites the first input slot with output 0
230
+ stack[1 ] = from (std::get<1 >(result));  // overwrites the second input slot with output 1
231
+ }
232
+
233
// Registration block for the (torchaudio, CPU) arguments given to
// STABLE_TORCH_LIBRARY_IMPL: binds the op name passed to m.impl to the
// stack-based wrapper boxed_compute rather than to compute() directly.
+ STABLE_TORCH_LIBRARY_IMPL (torchaudio, CPU, m) {
234
+ m.impl (" forced_align" , &boxed_compute);  // pass the address of the wrapper as the implementation
212
235
}
213
236
214
237
} // namespace cpu
0 commit comments