Skip to content

Commit be13f64

Browse files
committed
WIP
1 parent 7a94b04 commit be13f64

File tree

3 files changed

+70
-36
lines changed

3 files changed

+70
-36
lines changed

src/libtorchaudio/accessor.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
#pragma once
22

3-
#include <torch/torch.h>
3+
#include <torch/csrc/stable/tensor.h>
44
#include <type_traits>
55
#include <cstdarg>
66

7+
using torch::stable::Tensor;
8+
79
template<unsigned int k, typename T, bool IsConst = true>
810
class Accessor {
911
int64_t strides[k];
1012
T *data;
1113

1214
public:
13-
using tensor_type = typename std::conditional<IsConst, const torch::Tensor&, torch::Tensor&>::type;
15+
using tensor_type = typename std::conditional<IsConst, const Tensor&, Tensor&>::type;
1416

1517
Accessor(tensor_type tensor) {
1618
data = tensor.template data_ptr<T>();
Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
#include <libtorchaudio/accessor.h>
22
#include <cstdint>
33
#include <torch/torch.h>
4+
#include <torch/csrc/stable/tensor.h>
5+
#include <torch/csrc/stable/library.h>
46

57
using namespace std;
8+
using torch::stable::Tensor;
69

7-
bool test_accessor(const torch::Tensor& tensor) {
10+
bool test_accessor(const Tensor tensor) {
811
int64_t* data_ptr = tensor.template data_ptr<int64_t>();
912
auto accessor = Accessor<3, int64_t>(tensor);
10-
for (int i = 0; i < tensor.size(0); i++) {
11-
for (int j = 0; j < tensor.size(1); j++) {
12-
for (int k = 0; k < tensor.size(2); k++) {
13+
for (unsigned int i = 0; i < tensor.size(0); i++) {
14+
for (unsigned int j = 0; j < tensor.size(1); j++) {
15+
for (unsigned int k = 0; k < tensor.size(2); k++) {
1316
auto check = *(data_ptr++) == accessor.index(i, j, k);
1417
if (!check) {
1518
return false;
@@ -20,6 +23,12 @@ bool test_accessor(const torch::Tensor& tensor) {
2023
return true;
2124
}
2225

23-
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
24-
m.def("torchaudio::_test_accessor", &test_accessor);
26+
// Boxed (stable-ABI) entry point for `test_accessor`.
//
// Stack usage in this wrapper: `stack[0]` holds the input tensor handle on
// entry and is overwritten with the boolean result before returning.
//
// @param stack        argument/result slots shared with the caller
// @param num_args     number of inputs on the stack (not inspected here)
// @param num_outputs  number of outputs expected (not inspected here)
void boxed_test_accessor(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  (void)num_args;
  (void)num_outputs;

  Tensor t1(to<AtenTensorHandle>(stack[0]));
  // Forward to the accessor test itself: it takes a single tensor and returns
  // bool. `compute` is the forced-align entry point (five arguments, tuple
  // result) and is not the function this wrapper is meant to box.
  auto result = test_accessor(std::move(t1));
  stack[0] = from(result);
}
31+
32+
// Declare the operator schema. The stable library's `def` accepts only a
// schema string; kernels are attached separately through `impl`.
// NOTE(review): the argument name `tensor` mirrors `test_accessor`'s
// parameter — confirm no caller passes the argument by a different keyword.
STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
  m.def("_test_accessor(Tensor tensor) -> bool");
}

// Attach the boxed kernel to the schema declared above.
// NOTE(review): CompositeExplicitAutograd is used so the registration is not
// tied to one backend, like the earlier function-pointer `def` — confirm this
// is the intended dispatch key (the test itself dereferences host memory).
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CompositeExplicitAutograd, m) {
  m.impl("_test_accessor", &boxed_test_accessor);
}

src/libtorchaudio/forced_align/cpu/compute.cpp

Lines changed: 51 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,17 @@ namespace torchaudio {
1414
namespace alignment {
1515
namespace cpu {
1616

17-
17+
using torch::stable::Tensor;
1818

1919
// Inspired from
2020
// https://github.com/flashlight/sequence/blob/main/flashlight/lib/sequence/criterion/cpu/ConnectionistTemporalClassificationCriterion.cpp
21-
template <typename scalar_t, at::ScalarType target_scalar_type>
21+
template <typename scalar_t, typename target_t>
2222
void forced_align_impl(
23-
const torch::Tensor& logProbs,
24-
const torch::Tensor& targets,
25-
const int64_t blank,
26-
torch::Tensor& paths) {
23+
const Tensor logProbs,
24+
const Tensor targets,
25+
const Tensor blank,
26+
Tensor paths) {
2727
const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
28-
using target_t = typename std::
29-
conditional<target_scalar_type == torch::kInt, int, int64_t>::type;
3028
const auto batchIndex =
3129
0; // TODO: support batch version and use the real batch index
3230
const auto T = logProbs.size(1);
@@ -136,11 +134,11 @@ void forced_align_impl(
136134
}
137135
}
138136

139-
std::tuple<torch::Tensor, torch::Tensor> compute(
140-
const torch::Tensor& logProbs,
141-
const torch::Tensor& targets,
142-
const torch::Tensor& inputLengths,
143-
const torch::Tensor& targetLengths,
137+
std::tuple<Tensor, Tensor> compute(
138+
const Tensor& logProbs,
139+
const Tensor& targets,
140+
const Tensor& inputLengths,
141+
const Tensor& targetLengths,
144142
const int64_t blank) {
145143
TORCH_CHECK(logProbs.is_cpu(), "log_probs must be a CPU tensor");
146144
TORCH_CHECK(targets.is_cpu(), "targets must be a CPU tensor");
@@ -185,19 +183,31 @@ std::tuple<torch::Tensor, torch::Tensor> compute(
185183

186184
const auto B = logProbs.size(0);
187185
const auto T = logProbs.size(1);
188-
auto paths = torch::zeros(
189-
{B, T},
190-
torch::TensorOptions().device(targets.device()).dtype(targets.dtype()));
191-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
192-
logProbs.scalar_type(), "forced_align_impl", [&] {
193-
if (targets.scalar_type() == torch::kInt64) {
194-
forced_align_impl<scalar_t, torch::kInt64>(
195-
logProbs, targets, blank, paths);
196-
} else {
197-
forced_align_impl<scalar_t, torch::kInt32>(
198-
logProbs, targets, blank, paths);
199-
}
200-
});
186+
187+
int64_t paths_size[2] = {B, T};
188+
int64_t paths_stride[2] = {T, 1};
189+
AtenTensorHandle paths_h;
190+
aoti_torch_empty_strided(1, paths_size, paths_stride, targets_dtype, targets_device, targets_device_index, &paths_h);
191+
auto paths = Tensor(paths_h);
192+
193+
194+
if (targets.scalar_type() == aoti_torch_dtype_int64()) {
195+
if (logProbs.scalar_type() == aoti_torch_dtype_float64()) {
196+
forced_align_impl<float64, int64>(logProbs, targets, blank, paths);
197+
} else if (logProbs.scalar_type() == aoti_torch_dtype_float32()) {
198+
forced_align_impl<float32, int64>(logProbs, targets, blank, paths);
199+
} else if (logProbs.scalar_type() == aoti_torch_dtype_float16()) {
200+
forced_align_impl<float16, int64>(logProbs, targets, blank, paths);
201+
}
202+
} else if (targets.scalar_type() == aoti_torch_dtype_int32()) {
203+
if (logProbs.scalar_type() == aoti_torch_dtype_float64()) {
204+
forced_align_impl<float64, int32>(logProbs, targets, blank, paths);
205+
} else if (logProbs.scalar_type() == aoti_torch_dtype_float32()) {
206+
forced_align_impl<float32, int32>(logProbs, targets, blank, paths);
207+
} else if (logProbs.scalar_type() == aoti_torch_dtype_float16()) {
208+
forced_align_impl<float16, int32>(logProbs, targets, blank, paths);
209+
}
210+
}
201211
return std::make_tuple(
202212
paths,
203213
logProbs.index(
@@ -207,8 +217,21 @@ std::tuple<torch::Tensor, torch::Tensor> compute(
207217
paths.index({0})}));
208218
}
209219

210-
TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
211-
m.impl("forced_align", &compute);
220+
221+
// Boxed (stable-ABI) wrapper around `compute`.
//
// Reads four tensor handles from stack[0..3] and an int64 `blank` from
// stack[4], forwards them to `compute`, and stores the two tensors of the
// returned tuple into stack[0] and stack[1].
//
// @param stack        argument/result slots shared with the caller
// @param num_args     number of inputs on the stack (not inspected here)
// @param num_outputs  number of outputs expected (not inspected here)
void boxed_compute(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  // Unpack the inputs in positional order.
  Tensor logProbs(to<AtenTensorHandle>(stack[0]));
  Tensor targets(to<AtenTensorHandle>(stack[1]));
  Tensor inputLengths(to<AtenTensorHandle>(stack[2]));
  Tensor targetLengths(to<AtenTensorHandle>(stack[3]));
  const int64_t blank = to<int64_t>(stack[4]);

  auto [out0, out1] = compute(
      std::move(logProbs),
      std::move(targets),
      std::move(inputLengths),
      std::move(targetLengths),
      blank);

  // Write the results back over the first two stack slots.
  stack[0] = from(out0);
  stack[1] = from(out1);
}
232+
233+
// Register the boxed wrapper `boxed_compute` as the CPU implementation of the
// `forced_align` operator in the `torchaudio` library namespace.
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
  m.impl("forced_align", boxed_compute);
}
213236

214237
} // namespace cpu

0 commit comments

Comments
 (0)