Skip to content

Commit 2663def

Browse files
committed
Use 1d indexing in original layout for alphas_a
1 parent ced6124 commit 2663def

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

src/libtorchaudio/forced_align/cpu/compute.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,9 @@ void forced_align_impl(
2727
const auto L = targets.size(1);
2828
const auto S = 2 * L + 1;
2929

30-
auto alphas_a = new scalar_t[S][2]; // scalar_t is just logProbs.dtype()
31-
for (int i = 0; i < S; i++) {
32-
alphas_a[i][0] = kNegInfinity;
33-
alphas_a[i][1] = kNegInfinity;
30+
auto alphas_a = new scalar_t[2 * S]; // scalar_t is just logProbs.dtype()
31+
for (int i = 0; i < 2 * S; i++) {
32+
alphas_a[i] = kNegInfinity;
3433
}
3534

3635
torch::Tensor backPtr = torch::empty({T, S}, torch::kInt8).fill_(-1);
@@ -56,7 +55,7 @@ void forced_align_impl(
5655
auto end = (S == 1) ? 1 : 2;
5756
for (auto i = start; i < end; i++) {
5857
auto labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
59-
alphas_a[i][0] = logProbs_a[batchIndex][0][labelIdx];
58+
alphas_a[i] = logProbs_a[batchIndex][0][labelIdx]; // alphas_a[0, i]
6059
}
6160
for (auto t = 1; t < T; t++) {
6261
if (T - t <= L + R) {
@@ -79,18 +78,18 @@ void forced_align_impl(
7978
auto curIdxOffset = t % 2;
8079
auto prevIdxOffset = (t - 1) % 2;
8180
for (auto j = 0; j < S; ++j) {
82-
alphas_a[j][curIdxOffset] = -std::numeric_limits<scalar_t>::infinity();
81+
alphas_a[curIdxOffset * S + j] = -std::numeric_limits<scalar_t>::infinity(); // alphas_a[curIdxOffset][j]
8382
}
8483
if (start == 0) {
85-
alphas_a[0][curIdxOffset] =
86-
alphas_a[0][prevIdxOffset] + logProbs_a[batchIndex][t][blank];
84+
alphas_a[curIdxOffset * S] =
85+
alphas_a[prevIdxOffset * S] + logProbs_a[batchIndex][t][blank];
8786
backPtr_a[t][0] = 0;
8887
startloop += 1;
8988
}
9089

9190
for (auto i = startloop; i < end; i++) {
92-
auto x0 = alphas_a[i][prevIdxOffset];
93-
auto x1 = alphas_a[i - 1][prevIdxOffset];
91+
auto x0 = alphas_a[prevIdxOffset * S + i]; // alphas_a[prevIdxOffset][i];
92+
auto x1 = alphas_a[prevIdxOffset * S + i - 1]; // alphas_a[prevIdxOffset][i - 1];
9493
auto x2 = -std::numeric_limits<scalar_t>::infinity();
9594

9695
auto labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
@@ -101,7 +100,7 @@ void forced_align_impl(
101100
// (i != 1) just ensures we don't access targets[i - 2] if its i < 2
102101
if (i % 2 != 0 && i != 1 &&
103102
targets_a[batchIndex][i / 2] != targets_a[batchIndex][i / 2 - 1]) {
104-
x2 = alphas_a[i - 2][prevIdxOffset];
103+
x2 = alphas_a[prevIdxOffset * S + i - 2]; // alphas_a[prevIdxOffset][i - 2];
105104
}
106105
scalar_t result = 0.0;
107106
if (x2 > x1 && x2 > x0) {
@@ -114,11 +113,12 @@ void forced_align_impl(
114113
result = x0;
115114
backPtr_a[t][i] = 0;
116115
}
117-
alphas_a[i][curIdxOffset] = result + logProbs_a[batchIndex][t][labelIdx];
116+
alphas_a[curIdxOffset * S + i] = result + logProbs_a[batchIndex][t][labelIdx]; // alphas_a[curIdxOffset][i]
118117
}
119118
}
120119
auto idx1 = (T - 1) % 2;
121-
auto ltrIdx = alphas_a[S - 1][idx1] > alphas_a[S - 2][idx1] ? S - 1 : S - 2;
120+
auto ltrIdx = alphas_a[S * idx1 + S - 1] >
121+
alphas_a[S * idx1 + S - 2] ? S - 1 : S - 2; // alphas_a[idx1][S - 1], alphas_a[idx1][S - 2]
122122
delete[] alphas_a;
123123
// path stores the token index for each time step after force alignment.
124124
for (auto t = T - 1; t > -1; t--) {

0 commit comments

Comments
 (0)