#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
#include <libtorchaudio/accessor.h>
+#include <torch/headeronly/util/Half.h>

using namespace std;
@@ -22,7 +23,7 @@ template <typename scalar_t, typename target_t>
void forced_align_impl(
    const Tensor logProbs,
    const Tensor targets,
-    const Tensor blank,
+    target_t blank,
    Tensor paths) {
  const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
  const auto batchIndex =
@@ -143,15 +144,15 @@ std::tuple<Tensor, Tensor> compute(
  TORCH_CHECK(logProbs.is_cpu(), "log_probs must be a CPU tensor");
  TORCH_CHECK(targets.is_cpu(), "targets must be a CPU tensor");
  TORCH_CHECK(
-      logProbs.device() == targets.device(),
+      logProbs.get_device() == targets.get_device(),
      "log_probs and targets need to be on the same device");
  TORCH_CHECK(
-      logProbs.dtype() == torch::kFloat64 ||
-          logProbs.dtype() == torch::kFloat32 ||
-          logProbs.dtype() == torch::kFloat16,
+      logProbs.dtype() == aoti_torch_dtype_float64() ||
+          logProbs.dtype() == aoti_torch_dtype_float32() ||
+          logProbs.dtype() == aoti_torch_dtype_float16(),
      "log_probs must be float64, float32 or float16 (half) type");
  TORCH_CHECK(
-      targets.dtype() == torch::kInt32 || targets.dtype() == torch::kInt64,
+      targets.dtype() == aoti_torch_dtype_int32() || targets.dtype() == aoti_torch_dtype_int64(),
      "targets must be int32 or int64 type");
  TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous");
  TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
@@ -174,38 +175,41 @@ std::tuple<Tensor, Tensor> compute(
      blank >= 0 && blank < logProbs.size(-1),
      "blank must be within [0, num classes)");

-  TORCH_CHECK(
-      logProbs.size(1) == at::max(inputLengths).item().toInt(),
-      "input length mismatch");
-  TORCH_CHECK(
-      targets.size(1) == at::max(targetLengths).item().toInt(),
-      "target length mismatch");
+  // TODO: Requires port of `max` operator.
+  // TORCH_CHECK(
+  //     logProbs.size(1) == at::max(inputLengths).item().toInt(),
+  //     "input length mismatch");
+  // TORCH_CHECK(
+  //     targets.size(1) == at::max(targetLengths).item().toInt(),
+  //     "target length mismatch");

  const auto B = logProbs.size(0);
  const auto T = logProbs.size(1);

  int64_t paths_size[2] = {B, T};
  int64_t paths_stride[2] = {T, 1};
  AtenTensorHandle paths_h;
-  aoti_torch_empty_strided(1, paths_size, paths_stride, targets_dtype, targets_device, targets_device_index, &paths_h);
+  int32_t targets_device;
+  aoti_torch_get_device_type(targets.get(), &targets_device);
+  aoti_torch_empty_strided(2, paths_size, paths_stride, targets.dtype(), targets_device, targets.get_device(), &paths_h);
  auto paths = Tensor(paths_h);

  if (targets.dtype() == aoti_torch_dtype_int64()) {
-    if (logProbs.scalar_type() == aoti_torch_dtype_float64()) {
-      forced_align_impl<float64, int64>(logProbs, targets, blank, paths);
-    } else if (logProbs.scalar_type() == aoti_torch_dtype_float32()) {
-      forced_align_impl<float32, int64>(logProbs, targets, blank, paths);
-    } else if (logProbs.scalar_type() == aoti_torch_dtype_float16()) {
-      forced_align_impl<float16, int64>(logProbs, targets, blank, paths);
+    if (logProbs.dtype() == aoti_torch_dtype_float64()) {
+      forced_align_impl<double, int64_t>(logProbs, targets, blank, paths);
+    } else if (logProbs.dtype() == aoti_torch_dtype_float32()) {
+      forced_align_impl<float, int64_t>(logProbs, targets, blank, paths);
+    } else if (logProbs.dtype() == aoti_torch_dtype_float16()) {
+      forced_align_impl<c10::Half, int64_t>(logProbs, targets, blank, paths);
    }
-  } else if (targets.scalar_type() == aoti_torch_dtype_int32()) {
-    if (logProbs.scalar_type() == aoti_torch_dtype_float64()) {
-      forced_align_impl<float64, int32>(logProbs, targets, blank, paths);
-    } else if (logProbs.scalar_type() == aoti_torch_dtype_float32()) {
-      forced_align_impl<float32, int32>(logProbs, targets, blank, paths);
-    } else if (logProbs.scalar_type() == aoti_torch_dtype_float16()) {
-      forced_align_impl<float16, int32>(logProbs, targets, blank, paths);
+  } else if (targets.dtype() == aoti_torch_dtype_int32()) {
+    if (logProbs.dtype() == aoti_torch_dtype_float64()) {
+      forced_align_impl<double, int32_t>(logProbs, targets, blank, paths);
+    } else if (logProbs.dtype() == aoti_torch_dtype_float32()) {
+      forced_align_impl<float, int32_t>(logProbs, targets, blank, paths);
+    } else if (logProbs.dtype() == aoti_torch_dtype_float16()) {
+      forced_align_impl<c10::Half, int32_t>(logProbs, targets, blank, paths);
    }
  }
  return std::make_tuple(
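Side note, not part of the diff: the length checks commented out above with "TODO: Requires port of `max` operator" could be restored without `at::max` by scanning the length tensors directly. A minimal sketch, assuming `inputLengths`/`targetLengths` are contiguous 1-D int64 CPU tensors, that the `Tensor` wrapper exposes `get()` and `size()` as used elsewhere in this file, and that `<algorithm>` and `<limits>` are available:

// Hedged sketch: hand-rolled max over a contiguous 1-D int64 CPU tensor,
// as a stand-in until at::max is available through the stable ABI.
// max_length is a hypothetical helper, not part of the PR.
static int64_t max_length(const Tensor& lengths) {
  void* data = nullptr;
  // aoti_torch_get_data_ptr comes from the AOTI C shim included above.
  aoti_torch_get_data_ptr(lengths.get(), &data);
  const int64_t* p = static_cast<const int64_t*>(data);
  int64_t m = std::numeric_limits<int64_t>::min();
  for (int64_t i = 0; i < lengths.size(0); ++i) {
    m = std::max(m, p[i]);
  }
  return m;
}

// Usage, mirroring the original checks:
// TORCH_CHECK(logProbs.size(1) == max_length(inputLengths), "input length mismatch");
// TORCH_CHECK(targets.size(1) == max_length(targetLengths), "target length mismatch");

If the length tensors can also be int32, the sketch would need a dtype branch like the dispatch above before casting the data pointer.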