Skip to content

Commit b733629

Browse files
committed
Add MutAccessor
1 parent 4039399 commit b733629

File tree

1 file changed

+31
-15
lines changed

1 file changed

+31
-15
lines changed

src/libtorchaudio/forced_align/cpu/compute.cpp

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -40,22 +40,38 @@ class Accessor {
4040
va_end(args);
4141
return data[ix];
4242
}
43+
};
44+
45+
46+
template<unsigned int k, typename T>
47+
class MutAccessor {
48+
int64_t shape[k];
49+
T *data;
50+
51+
public:
52+
MutAccessor(torch::Tensor& tensor) {
53+
data = tensor.data_ptr<T>();
54+
for (int i = 0; i < k; i++) {
55+
shape[i] = tensor.size(i);
56+
}
57+
}
4358

44-
// void set_index(T val,...) {
45-
// va_list args;
46-
// va_start(args, k);
47-
// int64_t ix = 0;
48-
// for (int i = 0; i < k; i++) {
49-
// if (i == k - 1)
50-
// ix += va_arg(args, int);
51-
// else
52-
// ix += shape[i+1] * va_arg(args, int);
53-
// }
54-
// va_end(args);
55-
// data[ix] = val;
56-
// }
59+
void set_index(T value,...) {
60+
va_list args;
61+
va_start(args, k);
62+
int64_t ix = 0;
63+
for (int i = 0; i < k; i++) {
64+
if (i == k - 1)
65+
ix += va_arg(args, int);
66+
else
67+
ix += shape[i+1] * va_arg(args, int);
68+
}
69+
va_end(args);
70+
data[ix] = value;
71+
}
5772
};
5873

74+
5975
// Inspired from
6076
// https://github.com/flashlight/sequence/blob/main/flashlight/lib/sequence/criterion/cpu/ConnectionistTemporalClassificationCriterion.cpp
6177
template <typename scalar_t, at::ScalarType target_scalar_type>
@@ -86,7 +102,7 @@ void forced_align_impl(
86102

87103
auto logProbs_a = Accessor<3, scalar_t>(logProbs);
88104
auto targets_a = Accessor<2, target_t>(targets);
89-
auto paths_a = paths.accessor<target_t, 2>();
105+
auto paths_a = MutAccessor<2, target_t>(paths);
90106
auto R = 0;
91107
for (auto i = 1; i < L; i++) {
92108
if (targets_a.index(batchIndex, i) == targets_a.index(batchIndex, i - 1)) {
@@ -171,7 +187,7 @@ void forced_align_impl(
171187
// path stores the token index for each time step after force alignment.
172188
for (auto t = T - 1; t > -1; t--) {
173189
auto lbl_idx = ltrIdx % 2 == 0 ? blank : targets_a.index(batchIndex, ltrIdx / 2);
174-
paths_a[batchIndex][t] = lbl_idx;
190+
paths_a.set_index(lbl_idx, batchIndex, t);
175191
ltrIdx -= backPtr_a[t * S + ltrIdx];
176192
}
177193
}

0 commit comments

Comments
 (0)