Skip to content

Commit 2062dc7

Browse files
committed
Make alphas_a standard C array
1 parent 6fbc710 commit 2062dc7

File tree

1 file changed

+20
-16
lines changed

1 file changed

+20
-16
lines changed

src/libtorchaudio/forced_align/cpu/compute.cpp

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
#include <torch/script.h>
22
#include <torch/torch.h>
3+
#include <torch/csrc/stable/library.h>
4+
#include <torch/csrc/stable/tensor.h>
5+
#include <torch/csrc/stable/ops.h>
6+
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
37

48
using namespace std;
59

@@ -22,17 +26,17 @@ void forced_align_impl(
2226
const auto T = logProbs.size(1);
2327
const auto L = targets.size(1);
2428
const auto S = 2 * L + 1;
25-
torch::Tensor alphas = torch::empty(
26-
{2, S},
27-
torch::TensorOptions()
28-
.device(logProbs.device())
29-
.dtype(logProbs.dtype()))
30-
.fill_(kNegInfinity);
29+
30+
auto alphas_a = new scalar_t[S][2];
31+
for (int i = 0; i < S; i++) {
32+
alphas_a[i][0] = kNegInfinity;
33+
alphas_a[i][1] = kNegInfinity;
34+
}
35+
3136
torch::Tensor backPtr = torch::empty({T, S}, torch::kInt8).fill_(-1);
3237
auto logProbs_a = logProbs.accessor<scalar_t, 3>();
3338
auto targets_a = targets.accessor<target_t, 2>();
3439
auto paths_a = paths.accessor<target_t, 2>();
35-
auto alphas_a = alphas.accessor<scalar_t, 2>();
3640
auto backPtr_a = backPtr.accessor<int8_t, 2>();
3741
auto R = 0;
3842
for (auto i = 1; i < L; i++) {
@@ -52,7 +56,7 @@ void forced_align_impl(
5256
auto end = (S == 1) ? 1 : 2;
5357
for (auto i = start; i < end; i++) {
5458
auto labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
55-
alphas_a[0][i] = logProbs_a[batchIndex][0][labelIdx];
59+
alphas_a[i][0] = logProbs_a[batchIndex][0][labelIdx];
5660
}
5761
for (auto t = 1; t < T; t++) {
5862
if (T - t <= L + R) {
@@ -75,18 +79,18 @@ void forced_align_impl(
7579
auto curIdxOffset = t % 2;
7680
auto prevIdxOffset = (t - 1) % 2;
7781
for (auto j = 0; j < S; ++j) {
78-
alphas_a[curIdxOffset][j] = -std::numeric_limits<scalar_t>::infinity();
82+
alphas_a[j][curIdxOffset] = -std::numeric_limits<scalar_t>::infinity();
7983
}
8084
if (start == 0) {
81-
alphas_a[curIdxOffset][0] =
82-
alphas_a[prevIdxOffset][0] + logProbs_a[batchIndex][t][blank];
85+
alphas_a[0][curIdxOffset] =
86+
alphas_a[0][prevIdxOffset] + logProbs_a[batchIndex][t][blank];
8387
backPtr_a[t][0] = 0;
8488
startloop += 1;
8589
}
8690

8791
for (auto i = startloop; i < end; i++) {
88-
auto x0 = alphas_a[prevIdxOffset][i];
89-
auto x1 = alphas_a[prevIdxOffset][i - 1];
92+
auto x0 = alphas_a[i][prevIdxOffset];
93+
auto x1 = alphas_a[i - 1][prevIdxOffset];
9094
auto x2 = -std::numeric_limits<scalar_t>::infinity();
9195

9296
auto labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
@@ -97,7 +101,7 @@ void forced_align_impl(
97101
// (i != 1) just ensures we don't access targets[i - 2] if its i < 2
98102
if (i % 2 != 0 && i != 1 &&
99103
targets_a[batchIndex][i / 2] != targets_a[batchIndex][i / 2 - 1]) {
100-
x2 = alphas_a[prevIdxOffset][i - 2];
104+
x2 = alphas_a[i - 2][prevIdxOffset];
101105
}
102106
scalar_t result = 0.0;
103107
if (x2 > x1 && x2 > x0) {
@@ -110,11 +114,11 @@ void forced_align_impl(
110114
result = x0;
111115
backPtr_a[t][i] = 0;
112116
}
113-
alphas_a[curIdxOffset][i] = result + logProbs_a[batchIndex][t][labelIdx];
117+
alphas_a[i][curIdxOffset] = result + logProbs_a[batchIndex][t][labelIdx];
114118
}
115119
}
116120
auto idx1 = (T - 1) % 2;
117-
auto ltrIdx = alphas_a[idx1][S - 1] > alphas_a[idx1][S - 2] ? S - 1 : S - 2;
121+
auto ltrIdx = alphas_a[S - 1][idx1] > alphas_a[S - 2][idx1] ? S - 1 : S - 2;
118122
// path stores the token index for each time step after force alignment.
119123
for (auto t = T - 1; t > -1; t--) {
120124
auto lbl_idx = ltrIdx % 2 == 0 ? blank : targets_a[batchIndex][ltrIdx / 2];

0 commit comments

Comments
 (0)