Skip to content

Commit 4039399

Browse files
committed
Create Accessor class
1 parent e70113c commit 4039399

File tree

1 file changed (+60 additions, −14 deletions)

src/libtorchaudio/forced_align/cpu/compute.cpp

Lines changed: 60 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,58 @@
44
#include <torch/csrc/stable/tensor.h>
55
#include <torch/csrc/stable/ops.h>
66
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
7+
#include <torch/csrc/inductor/aoti_torch/utils.h>
8+
#include <cstdarg>
9+
710

811
using namespace std;
912

1013
namespace torchaudio {
1114
namespace alignment {
1215
namespace cpu {
16+
17+
template<unsigned int k, typename T>
18+
class Accessor {
19+
int64_t shape[k];
20+
T *data;
21+
22+
public:
23+
Accessor(const torch::Tensor& tensor) {
24+
data = tensor.data_ptr<T>();
25+
for (int i = 0; i < k; i++) {
26+
shape[i] = tensor.size(i);
27+
}
28+
}
29+
30+
T index(...) {
31+
va_list args;
32+
va_start(args, k);
33+
int64_t ix = 0;
34+
for (int i = 0; i < k; i++) {
35+
if (i == k - 1)
36+
ix += va_arg(args, int);
37+
else
38+
ix += shape[i+1] * va_arg(args, int);
39+
}
40+
va_end(args);
41+
return data[ix];
42+
}
43+
44+
// void set_index(T val,...) {
45+
// va_list args;
46+
// va_start(args, k);
47+
// int64_t ix = 0;
48+
// for (int i = 0; i < k; i++) {
49+
// if (i == k - 1)
50+
// ix += va_arg(args, int);
51+
// else
52+
// ix += shape[i+1] * va_arg(args, int);
53+
// }
54+
// va_end(args);
55+
// data[ix] = val;
56+
// }
57+
};
58+
1359
// Inspired from
1460
// https://github.com/flashlight/sequence/blob/main/flashlight/lib/sequence/criterion/cpu/ConnectionistTemporalClassificationCriterion.cpp
1561
template <typename scalar_t, at::ScalarType target_scalar_type>
@@ -38,12 +84,12 @@ void forced_align_impl(
3884
backPtr_a[i] = -1;
3985
}
4086

41-
auto logProbs_a = logProbs.accessor<scalar_t, 3>();
42-
auto targets_a = targets.accessor<target_t, 2>();
87+
auto logProbs_a = Accessor<3, scalar_t>(logProbs);
88+
auto targets_a = Accessor<2, target_t>(targets);
4389
auto paths_a = paths.accessor<target_t, 2>();
4490
auto R = 0;
4591
for (auto i = 1; i < L; i++) {
46-
if (targets_a[batchIndex][i] == targets_a[batchIndex][i - 1]) {
92+
if (targets_a.index(batchIndex, i) == targets_a.index(batchIndex, i - 1)) {
4793
++R;
4894
}
4995
}
@@ -58,22 +104,22 @@ void forced_align_impl(
58104
auto start = T - (L + R) > 0 ? 0 : 1;
59105
auto end = (S == 1) ? 1 : 2;
60106
for (auto i = start; i < end; i++) {
61-
auto labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
62-
alphas_a[i][0] = logProbs_a[batchIndex][0][labelIdx];
107+
auto labelIdx = (i % 2 == 0) ? blank : targets_a.index(batchIndex, i / 2);
108+
alphas_a[i][0] = logProbs_a.index(batchIndex,0,labelIdx);
63109
}
64110
for (auto t = 1; t < T; t++) {
65111
if (T - t <= L + R) {
66112
if ((start % 2 == 1) &&
67-
targets_a[batchIndex][start / 2] !=
68-
targets_a[batchIndex][start / 2 + 1]) {
113+
targets_a.index(batchIndex, start / 2) !=
114+
targets_a.index(batchIndex, start / 2 + 1)) {
69115
start = start + 1;
70116
}
71117
start = start + 1;
72118
}
73119
if (t <= L + R) {
74120
if (end % 2 == 0 && end < 2 * L &&
75-
targets_a[batchIndex][end / 2 - 1] !=
76-
targets_a[batchIndex][end / 2]) {
121+
targets_a.index(batchIndex, end / 2 - 1) !=
122+
targets_a.index(batchIndex, end / 2)) {
77123
end = end + 1;
78124
}
79125
end = end + 1;
@@ -86,7 +132,7 @@ void forced_align_impl(
86132
}
87133
if (start == 0) {
88134
alphas_a[0][curIdxOffset] =
89-
alphas_a[0][prevIdxOffset] + logProbs_a[batchIndex][t][blank];
135+
alphas_a[0][prevIdxOffset] + logProbs_a.index(batchIndex, t, blank);
90136
backPtr_a[S * t] = 0;
91137
startloop += 1;
92138
}
@@ -96,14 +142,14 @@ void forced_align_impl(
96142
auto x1 = alphas_a[i - 1][prevIdxOffset];
97143
auto x2 = -std::numeric_limits<scalar_t>::infinity();
98144

99-
auto labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
145+
auto labelIdx = (i % 2 == 0) ? blank : targets_a.index(batchIndex, i / 2);
100146

101147
// In CTC, the optimal path may optionally chose to skip a blank label.
102148
// x2 represents skipping a letter, and can only happen if we're not
103149
// currently on a blank_label, and we're not on a repeat letter
104150
// (i != 1) just ensures we don't access targets[i - 2] if its i < 2
105151
if (i % 2 != 0 && i != 1 &&
106-
targets_a[batchIndex][i / 2] != targets_a[batchIndex][i / 2 - 1]) {
152+
targets_a.index(batchIndex, i / 2) != targets_a.index(batchIndex, i / 2 - 1)) {
107153
x2 = alphas_a[i - 2][prevIdxOffset];
108154
}
109155
scalar_t result = 0.0;
@@ -117,14 +163,14 @@ void forced_align_impl(
117163
result = x0;
118164
backPtr_a[t * S + i] = 0;
119165
}
120-
alphas_a[i][curIdxOffset] = result + logProbs_a[batchIndex][t][labelIdx];
166+
alphas_a[i][curIdxOffset] = result + logProbs_a.index(batchIndex, t, labelIdx);
121167
}
122168
}
123169
auto idx1 = (T - 1) % 2;
124170
auto ltrIdx = alphas_a[S - 1][idx1] > alphas_a[S - 2][idx1] ? S - 1 : S - 2;
125171
// path stores the token index for each time step after force alignment.
126172
for (auto t = T - 1; t > -1; t--) {
127-
auto lbl_idx = ltrIdx % 2 == 0 ? blank : targets_a[batchIndex][ltrIdx / 2];
173+
auto lbl_idx = ltrIdx % 2 == 0 ? blank : targets_a.index(batchIndex, ltrIdx / 2);
128174
paths_a[batchIndex][t] = lbl_idx;
129175
ltrIdx -= backPtr_a[t * S + ltrIdx];
130176
}

0 commit comments

Comments
 (0)