
Commit 77fd1ad

Use stable tensors throughout forced_align code

1 parent: 258ca00

File tree: 3 files changed (+50, -34 lines)

src/libtorchaudio/accessor.h
src/libtorchaudio/accessor_tests.cpp
src/libtorchaudio/forced_align/cpu/compute.cpp
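
The common thread in all three files is a move off libtorch's C++ at::Tensor API onto the stable ABI surface: torch::stable::Tensor plus the AOTInductor C shim. Below is a minimal sketch of the substitution pattern the diffs apply; the function example() is hypothetical and added only for illustration:

#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>

// Hypothetical illustration of the API substitution; not code from the commit.
void example(torch::stable::Tensor t) {
  // at::Tensor offers a typed accessor: float* p = t.data_ptr<float>();
  // The stable tensor exposes an untyped pointer, so callers cast:
  float* p = (float*)t.data_ptr();

  // at::Tensor dtype checks compare against torch::kFloat32 and friends;
  // stable code compares against the shim's dtype codes instead:
  bool is_f32 = (t.dtype() == aoti_torch_dtype_float32());

  (void)p;
  (void)is_f32;
}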

src/libtorchaudio/accessor.h

Lines changed: 4 additions & 4 deletions

@@ -15,8 +15,8 @@ class Accessor {
   using tensor_type = typename std::conditional<IsConst, const Tensor&, Tensor&>::type;

   Accessor(tensor_type tensor) {
-    data = tensor.template data_ptr<T>();
-    for (int i = 0; i < k; i++) {
+    data = (T*)tensor.template data_ptr();
+    for (unsigned int i = 0; i < k; i++) {
       strides[i] = tensor.stride(i);
     }
   }
@@ -25,7 +25,7 @@ class Accessor {
     va_list args;
     va_start(args, k);
     int64_t ix = 0;
-    for (int i = 0; i < k; i++) {
+    for (unsigned int i = 0; i < k; i++) {
       ix += strides[i] * va_arg(args, int);
     }
     va_end(args);
@@ -37,7 +37,7 @@ class Accessor {
     va_list args;
     va_start(args, value);
     int64_t ix = 0;
-    for (int i = 0; i < k; i++) {
+    for (unsigned int i = 0; i < k; i++) {
       ix += strides[i] * va_arg(args, int);
     }
     va_end(args);
src/libtorchaudio/accessor_tests.cpp

Lines changed: 16 additions & 4 deletions

@@ -4,11 +4,15 @@
 #include <torch/csrc/stable/tensor.h>
 #include <torch/csrc/stable/library.h>

+namespace torchaudio {
+
+namespace accessor_tests {
+
 using namespace std;
 using torch::stable::Tensor;

 bool test_accessor(const Tensor tensor) {
-  int64_t* data_ptr = tensor.template data_ptr<int64_t>();
+  int64_t* data_ptr = (int64_t*)tensor.data_ptr();
   auto accessor = Accessor<3, int64_t>(tensor);
   for (unsigned int i = 0; i < tensor.size(0); i++) {
     for (unsigned int j = 0; j < tensor.size(1); j++) {
@@ -25,10 +29,18 @@ bool test_accessor(const Tensor tensor) {

 void boxed_test_accessor(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
   Tensor t1(to<AtenTensorHandle>(stack[0]));
-  auto result = compute(std::move(t1));
+  auto result = test_accessor(std::move(t1));
   stack[0] = from(result);
 }

-STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
-  m.def("torchaudio::_test_accessor", &boxed_test_accessor);
+TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
+  m.def(
+      "_test_accessor(Tensor log_probs) -> bool");
+}
+
+STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
+  m.impl("torchaudio::_test_accessor", &boxed_test_accessor);
+}
+
+}
 }
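
The registration change above splits the op into a schema definition (TORCH_LIBRARY_FRAGMENT) and a CPU kernel binding (STABLE_TORCH_LIBRARY_IMPL), with the kernel supplied as a boxed function. To make the boxed calling convention concrete, here is a hypothetical two-argument op; my_op and boxed_my_op are invented names, and the convention shown is the one the file above relies on:

// Hypothetical boxed wrapper: arguments arrive in stack[0..num_args-1] as
// StableIValues; results are written back starting at stack[0].
void boxed_my_op(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  Tensor a(to<AtenTensorHandle>(stack[0]));
  Tensor b(to<AtenTensorHandle>(stack[1]));
  bool result = my_op(std::move(a), std::move(b));  // my_op: assumed helper
  stack[0] = from(result);
}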

src/libtorchaudio/forced_align/cpu/compute.cpp

Lines changed: 30 additions & 26 deletions

@@ -6,6 +6,7 @@
 #include <torch/csrc/inductor/aoti_torch/c/shim.h>
 #include <torch/csrc/inductor/aoti_torch/utils.h>
 #include <libtorchaudio/accessor.h>
+#include <torch/headeronly/util/Half.h>


 using namespace std;
@@ -22,7 +23,7 @@ template <typename scalar_t, typename target_t>
 void forced_align_impl(
     const Tensor logProbs,
     const Tensor targets,
-    const Tensor blank,
+    target_t blank,
     Tensor paths) {
   const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
   const auto batchIndex =
@@ -143,15 +144,15 @@ std::tuple<Tensor, Tensor> compute(
   TORCH_CHECK(logProbs.is_cpu(), "log_probs must be a CPU tensor");
   TORCH_CHECK(targets.is_cpu(), "targets must be a CPU tensor");
   TORCH_CHECK(
-      logProbs.device() == targets.device(),
+      logProbs.get_device() == targets.get_device(),
       "log_probs and targets need to be on the same device");
   TORCH_CHECK(
-      logProbs.dtype() == torch::kFloat64 ||
-          logProbs.dtype() == torch::kFloat32 ||
-          logProbs.dtype() == torch::kFloat16,
+      logProbs.dtype() == aoti_torch_dtype_float64() ||
+          logProbs.dtype() == aoti_torch_dtype_float32() ||
+          logProbs.dtype() == aoti_torch_dtype_float16(),
       "log_probs must be float64, float32 or float16 (half) type");
   TORCH_CHECK(
-      targets.dtype() == torch::kInt32 || targets.dtype() == torch::kInt64,
+      targets.dtype() == aoti_torch_dtype_int32() || targets.dtype() == aoti_torch_dtype_int64(),
       "targets must be int32 or int64 type");
   TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous");
   TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
@@ -174,38 +175,41 @@ std::tuple<Tensor, Tensor> compute(
       blank >= 0 && blank < logProbs.size(-1),
       "blank must be within [0, num classes)");

-  TORCH_CHECK(
-      logProbs.size(1) == at::max(inputLengths).item().toInt(),
-      "input length mismatch");
-  TORCH_CHECK(
-      targets.size(1) == at::max(targetLengths).item().toInt(),
-      "target length mismatch");
+  // TODO: Requires port of `max` operator.
+  // TORCH_CHECK(
+  //     logProbs.size(1) == at::max(inputLengths).item().toInt(),
+  //     "input length mismatch");
+  // TORCH_CHECK(
+  //     targets.size(1) == at::max(targetLengths).item().toInt(),
+  //     "target length mismatch");

   const auto B = logProbs.size(0);
   const auto T = logProbs.size(1);

   int64_t paths_size[2] = {B, T};
   int64_t paths_stride[2] = {T, 1};
   AtenTensorHandle paths_h;
-  aoti_torch_empty_strided(1, paths_size, paths_stride, targets_dtype, targets_device, targets_device_index, &paths_h);
+  int32_t targets_device;
+  aoti_torch_get_device_type(targets.get(), &targets_device);
+  aoti_torch_empty_strided(1, paths_size, paths_stride, targets.dtype(), targets_device, targets.get_device(), &paths_h);
   auto paths = Tensor(paths_h);


   if (targets.dtype() == aoti_torch_dtype_int64()) {
-    if (logProbs.scalar_type() == aoti_torch_dtype_float64()) {
-      forced_align_impl<float64, int64>(logProbs, targets, blank, paths);
-    } else if (logProbs.scalar_type() == aoti_torch_dtype_float32()) {
-      forced_align_impl<float32, int64>(logProbs, targets, blank, paths);
-    } else if (logProbs.scalar_type() == aoti_torch_dtype_float16()) {
-      forced_align_impl<float16, int64>(logProbs, targets, blank, paths);
+    if (logProbs.dtype() == aoti_torch_dtype_float64()) {
+      forced_align_impl<double, int64_t>(logProbs, targets, blank, paths);
+    } else if (logProbs.dtype() == aoti_torch_dtype_float32()) {
+      forced_align_impl<float, int64_t>(logProbs, targets, blank, paths);
+    } else if (logProbs.dtype() == aoti_torch_dtype_float16()) {
+      forced_align_impl<c10::Half, int64_t>(logProbs, targets, blank, paths);
     }
-  } else if (targets.scalar_type() == aoti_torch_dtype_int32()) {
-    if (logProbs.scalar_type() == aoti_torch_dtype_float64()) {
-      forced_align_impl<float64, int32>(logProbs, targets, blank, paths);
-    } else if (logProbs.scalar_type() == aoti_torch_dtype_float32()) {
-      forced_align_impl<float32, int32>(logProbs, targets, blank, paths);
-    } else if (logProbs.scalar_type() == aoti_torch_dtype_float16()) {
-      forced_align_impl<float16, int32>(logProbs, targets, blank, paths);
+  } else if (targets.dtype() == aoti_torch_dtype_int32()) {
+    if (logProbs.dtype() == aoti_torch_dtype_float64()) {
+      forced_align_impl<double, int32_t>(logProbs, targets, blank, paths);
+    } else if (logProbs.dtype() == aoti_torch_dtype_float32()) {
+      forced_align_impl<float, int32_t>(logProbs, targets, blank, paths);
+    } else if (logProbs.dtype() == aoti_torch_dtype_float16()) {
+      forced_align_impl<c10::Half, int32_t>(logProbs, targets, blank, paths);
     }
   }
   return std::make_tuple(
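
The hand-written nested dispatch above maps AOTI dtype codes to concrete template arguments (double, float, and c10::Half, the last made available by the new header-only include). The same mapping could be factored into a helper; this is a hypothetical refactor sketch, assuming the shim's dtype functions return plain integer codes, as the comparisons above suggest:

// Hypothetical helper: invoke f with a value of the C++ type corresponding
// to the given AOTI floating-point dtype code.
template <typename F>
void dispatch_float_dtype(int32_t dtype, F&& f) {
  if (dtype == aoti_torch_dtype_float64()) {
    f(double{});
  } else if (dtype == aoti_torch_dtype_float32()) {
    f(float{});
  } else if (dtype == aoti_torch_dtype_float16()) {
    f(c10::Half{});
  }
}

// Usage inside compute(), for the int64 targets branch:
// dispatch_float_dtype(logProbs.dtype(), [&](auto s) {
//   using scalar_t = decltype(s);
//   forced_align_impl<scalar_t, int64_t>(logProbs, targets, blank, paths);
// });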
