Skip to content

Commit eb50150

Browse files
committed
Merge branch 'forced_align_backptr' into forced_align_accessors
2 parents 77fd1ad + 71ce212 commit eb50150

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

src/libtorchaudio/forced_align/cpu/compute.cpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ void forced_align_impl(
3232
const auto L = targets.size(1);
3333
const auto S = 2 * L + 1;
3434

35-
auto alphas_a = new scalar_t[S][2];
35+
auto alphas_a = new scalar_t[S][2]; // scalar_t is just logProbs.dtype()
3636
for (int i = 0; i < S; i++) {
3737
alphas_a[i][0] = kNegInfinity;
3838
alphas_a[i][1] = kNegInfinity;
@@ -91,8 +91,13 @@ void forced_align_impl(
9191
}
9292
if (start == 0) {
9393
alphas_a[0][curIdxOffset] =
94+
<<<<<<< HEAD
9495
alphas_a[0][prevIdxOffset] + logProbs_a.index(batchIndex, t, blank);
9596
backPtr_a[S * t] = 0;
97+
=======
98+
alphas_a[0][prevIdxOffset] + logProbs_a[batchIndex][t][blank];
99+
backPtr_a[S * t] = 0; // backPtr_a[t][0] = 0
100+
>>>>>>> forced_align_backptr
96101
startloop += 1;
97102
}
98103

@@ -114,25 +119,27 @@ void forced_align_impl(
114119
scalar_t result = 0.0;
115120
if (x2 > x1 && x2 > x0) {
116121
result = x2;
117-
backPtr_a[t * S + i] = 2;
122+
backPtr_a[t * S + i] = 2; // backPtr_a[t][i] = 2
118123
} else if (x1 > x0 && x1 > x2) {
119124
result = x1;
120-
backPtr_a[t * S + i] = 1;
125+
backPtr_a[t * S + i] = 1; // backPtr_a[t][i] = 1
121126
} else {
122127
result = x0;
123-
backPtr_a[t * S + i] = 0;
128+
backPtr_a[t * S + i] = 0; // backPtr_a[t][i] = 0
124129
}
125130
alphas_a[i][curIdxOffset] = result + logProbs_a.index(batchIndex, t, labelIdx);
126131
}
127132
}
128133
auto idx1 = (T - 1) % 2;
129134
auto ltrIdx = alphas_a[S - 1][idx1] > alphas_a[S - 2][idx1] ? S - 1 : S - 2;
135+
delete[] alphas_a;
130136
// path stores the token index for each time step after force alignment.
131137
for (auto t = T - 1; t > -1; t--) {
132138
auto lbl_idx = ltrIdx % 2 == 0 ? blank : targets_a.index(batchIndex, ltrIdx / 2);
133139
paths_a.set_index(lbl_idx, batchIndex, t);
134-
ltrIdx -= backPtr_a[t * S + ltrIdx];
140+
ltrIdx -= backPtr_a[t * S + ltrIdx]; // backPtr_a[t][ltrIdx]
135141
}
142+
delete[] backPtr_a;
136143
}
137144

138145
std::tuple<Tensor, Tensor> compute(

0 commit comments

Comments
 (0)