Skip to content

Commit e500f0c

Browse files
[STABLE FORCED_ALIGN ABI PORT] Make alphas_a standard C array in forced_align (#4020)
* Make alphas_a standard C array
* Add requested comment about scalar_t
* Free alphas_a array
* Use 1d indexing in original layout for alphas_a

---------

Co-authored-by: Sam Anklesaria <[email protected]>
1 parent c819dbe commit e500f0c

File tree

1 file changed

+21
-16
lines changed

1 file changed

+21
-16
lines changed

src/libtorchaudio/forced_align/cpu/compute.cpp

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
#include <torch/script.h>
22
#include <torch/torch.h>
3+
#include <torch/csrc/stable/library.h>
4+
#include <torch/csrc/stable/tensor.h>
5+
#include <torch/csrc/stable/ops.h>
6+
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
37

48
using namespace std;
59

@@ -22,17 +26,16 @@ void forced_align_impl(
2226
const auto T = logProbs.size(1);
2327
const auto L = targets.size(1);
2428
const auto S = 2 * L + 1;
25-
torch::Tensor alphas = torch::empty(
26-
{2, S},
27-
torch::TensorOptions()
28-
.device(logProbs.device())
29-
.dtype(logProbs.dtype()))
30-
.fill_(kNegInfinity);
29+
30+
auto alphas_a = new scalar_t[2 * S]; // scalar_t is just logProbs.dtype()
31+
for (int i = 0; i < 2 * S; i++) {
32+
alphas_a[i] = kNegInfinity;
33+
}
34+
3135
torch::Tensor backPtr = torch::empty({T, S}, torch::kInt8).fill_(-1);
3236
auto logProbs_a = logProbs.accessor<scalar_t, 3>();
3337
auto targets_a = targets.accessor<target_t, 2>();
3438
auto paths_a = paths.accessor<target_t, 2>();
35-
auto alphas_a = alphas.accessor<scalar_t, 2>();
3639
auto backPtr_a = backPtr.accessor<int8_t, 2>();
3740
auto R = 0;
3841
for (auto i = 1; i < L; i++) {
@@ -52,7 +55,7 @@ void forced_align_impl(
5255
auto end = (S == 1) ? 1 : 2;
5356
for (auto i = start; i < end; i++) {
5457
auto labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
55-
alphas_a[0][i] = logProbs_a[batchIndex][0][labelIdx];
58+
alphas_a[i] = logProbs_a[batchIndex][0][labelIdx]; // alphas_a[0, i]
5659
}
5760
for (auto t = 1; t < T; t++) {
5861
if (T - t <= L + R) {
@@ -75,18 +78,18 @@ void forced_align_impl(
7578
auto curIdxOffset = t % 2;
7679
auto prevIdxOffset = (t - 1) % 2;
7780
for (auto j = 0; j < S; ++j) {
78-
alphas_a[curIdxOffset][j] = -std::numeric_limits<scalar_t>::infinity();
81+
alphas_a[curIdxOffset * S + j] = -std::numeric_limits<scalar_t>::infinity(); // alphas_a[curIdxOffset][j]
7982
}
8083
if (start == 0) {
81-
alphas_a[curIdxOffset][0] =
82-
alphas_a[prevIdxOffset][0] + logProbs_a[batchIndex][t][blank];
84+
alphas_a[curIdxOffset * S] =
85+
alphas_a[prevIdxOffset * S] + logProbs_a[batchIndex][t][blank];
8386
backPtr_a[t][0] = 0;
8487
startloop += 1;
8588
}
8689

8790
for (auto i = startloop; i < end; i++) {
88-
auto x0 = alphas_a[prevIdxOffset][i];
89-
auto x1 = alphas_a[prevIdxOffset][i - 1];
91+
auto x0 = alphas_a[prevIdxOffset * S + i]; // alphas_a[prevIdxOffset][i];
92+
auto x1 = alphas_a[prevIdxOffset * S + i - 1]; // alphas_a[prevIdxOffset][i - 1];
9093
auto x2 = -std::numeric_limits<scalar_t>::infinity();
9194

9295
auto labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
@@ -97,7 +100,7 @@ void forced_align_impl(
97100
// (i != 1) just ensures we don't access targets[i - 2] if its i < 2
98101
if (i % 2 != 0 && i != 1 &&
99102
targets_a[batchIndex][i / 2] != targets_a[batchIndex][i / 2 - 1]) {
100-
x2 = alphas_a[prevIdxOffset][i - 2];
103+
x2 = alphas_a[prevIdxOffset * S + i - 2]; // alphas_a[prevIdxOffset][i - 2];
101104
}
102105
scalar_t result = 0.0;
103106
if (x2 > x1 && x2 > x0) {
@@ -110,11 +113,13 @@ void forced_align_impl(
110113
result = x0;
111114
backPtr_a[t][i] = 0;
112115
}
113-
alphas_a[curIdxOffset][i] = result + logProbs_a[batchIndex][t][labelIdx];
116+
alphas_a[curIdxOffset * S + i] = result + logProbs_a[batchIndex][t][labelIdx]; // alphas_a[curIdxOffset][i]
114117
}
115118
}
116119
auto idx1 = (T - 1) % 2;
117-
auto ltrIdx = alphas_a[idx1][S - 1] > alphas_a[idx1][S - 2] ? S - 1 : S - 2;
120+
auto ltrIdx = alphas_a[S * idx1 + S - 1] >
121+
alphas_a[S * idx1 + S - 2] ? S - 1 : S - 2; // alphas_a[idx1][S - 1], alphas_a[idx1][S - 2]
122+
delete[] alphas_a;
118123
// path stores the token index for each time step after force alignment.
119124
for (auto t = T - 1; t > -1; t--) {
120125
auto lbl_idx = ltrIdx % 2 == 0 ? blank : targets_a[batchIndex][ltrIdx / 2];

0 commit comments

Comments
 (0)