Skip to content

Commit 5fa467d

Browse files
committed
Merge branch 'stable_forced_align' into forced_align_backptr
2 parents 71ce212 + 2663def commit 5fa467d

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

src/libtorchaudio/forced_align/cpu/compute.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,9 @@ void forced_align_impl(
2727
const auto L = targets.size(1);
2828
const auto S = 2 * L + 1;
2929

30-
auto alphas_a = new scalar_t[S][2]; // scalar_t is just logProbs.dtype()
31-
for (int i = 0; i < S; i++) {
32-
alphas_a[i][0] = kNegInfinity;
33-
alphas_a[i][1] = kNegInfinity;
30+
auto alphas_a = new scalar_t[2 * S]; // scalar_t is just logProbs.dtype()
31+
for (int i = 0; i < 2 * S; i++) {
32+
alphas_a[i] = kNegInfinity;
3433
}
3534

3635
auto backPtr_a = new int8_t[T * S];
@@ -59,7 +58,7 @@ void forced_align_impl(
5958
auto end = (S == 1) ? 1 : 2;
6059
for (auto i = start; i < end; i++) {
6160
auto labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
62-
alphas_a[i][0] = logProbs_a[batchIndex][0][labelIdx];
61+
alphas_a[i] = logProbs_a[batchIndex][0][labelIdx]; // alphas_a[0, i]
6362
}
6463
for (auto t = 1; t < T; t++) {
6564
if (T - t <= L + R) {
@@ -82,18 +81,18 @@ void forced_align_impl(
8281
auto curIdxOffset = t % 2;
8382
auto prevIdxOffset = (t - 1) % 2;
8483
for (auto j = 0; j < S; ++j) {
85-
alphas_a[j][curIdxOffset] = -std::numeric_limits<scalar_t>::infinity();
84+
alphas_a[curIdxOffset * S + j] = -std::numeric_limits<scalar_t>::infinity(); // alphas_a[curIdxOffset][j]
8685
}
8786
if (start == 0) {
88-
alphas_a[0][curIdxOffset] =
89-
alphas_a[0][prevIdxOffset] + logProbs_a[batchIndex][t][blank];
87+
alphas_a[curIdxOffset * S] =
88+
alphas_a[prevIdxOffset * S] + logProbs_a[batchIndex][t][blank]; // alphas_a[curIdxOffset][0]
9089
backPtr_a[S * t] = 0; // backPtr_a[t][0] = 0
9190
startloop += 1;
9291
}
9392

9493
for (auto i = startloop; i < end; i++) {
95-
auto x0 = alphas_a[i][prevIdxOffset];
96-
auto x1 = alphas_a[i - 1][prevIdxOffset];
94+
auto x0 = alphas_a[prevIdxOffset * S + i]; // alphas_a[prevIdxOffset][i];
95+
auto x1 = alphas_a[prevIdxOffset * S + i - 1]; // alphas_a[prevIdxOffset][i - 1];
9796
auto x2 = -std::numeric_limits<scalar_t>::infinity();
9897

9998
auto labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
@@ -104,7 +103,7 @@ void forced_align_impl(
104103
// (i != 1) just ensures we don't access targets[i - 2] if its i < 2
105104
if (i % 2 != 0 && i != 1 &&
106105
targets_a[batchIndex][i / 2] != targets_a[batchIndex][i / 2 - 1]) {
107-
x2 = alphas_a[i - 2][prevIdxOffset];
106+
x2 = alphas_a[prevIdxOffset * S + i - 2]; // alphas_a[prevIdxOffset][i - 2];
108107
}
109108
scalar_t result = 0.0;
110109
if (x2 > x1 && x2 > x0) {
@@ -117,11 +116,12 @@ void forced_align_impl(
117116
result = x0;
118117
backPtr_a[t * S + i] = 0; // backPtr_a[t][i] = 0
119118
}
120-
alphas_a[i][curIdxOffset] = result + logProbs_a[batchIndex][t][labelIdx];
119+
alphas_a[curIdxOffset * S + i] = result + logProbs_a[batchIndex][t][labelIdx]; // alphas_a[curIdxOffset][i]
121120
}
122121
}
123122
auto idx1 = (T - 1) % 2;
124-
auto ltrIdx = alphas_a[S - 1][idx1] > alphas_a[S - 2][idx1] ? S - 1 : S - 2;
123+
auto ltrIdx = alphas_a[S * idx1 + S - 1] >
124+
alphas_a[S * idx1 + S - 2] ? S - 1 : S - 2; // alphas_a[idx1][S - 1], alphas_a[idx1][S - 2]
125125
delete[] alphas_a;
126126
// path stores the token index for each time step after force alignment.
127127
for (auto t = T - 1; t > -1; t--) {

0 commit comments

Comments
 (0)