Skip to content

Commit 9beb34a

Browse files
committed
Fix multidimensional indexing bug
1 parent b733629 commit 9beb34a

File tree

1 file changed

+22
-13
lines changed

1 file changed

+22
-13
lines changed

src/libtorchaudio/forced_align/cpu/compute.cpp

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,27 @@ namespace torchaudio {
1414
namespace alignment {
1515
namespace cpu {
1616

17+
// Compute strides for row-major indexing
18+
template<unsigned int k>
19+
void reverse_cumprod(int64_t (&strides)[k]) {
20+
// Convert dimensions to strides: stride[i] = product of dimensions [i+1..k-1]
21+
for (int i = k - 2; i >= 0; i--) {
22+
strides[i] = strides[i] * strides[i + 1];
23+
}
24+
}
25+
1726
template<unsigned int k, typename T>
1827
class Accessor {
19-
int64_t shape[k];
28+
int64_t strides[k-1];
2029
T *data;
2130

2231
public:
2332
Accessor(const torch::Tensor& tensor) {
2433
data = tensor.data_ptr<T>();
25-
for (int i = 0; i < k; i++) {
26-
shape[i] = tensor.size(i);
34+
for (int i = 1; i < k; i++) {
35+
strides[i-1] = tensor.size(i);
2736
}
37+
reverse_cumprod<k-1>(strides);
2838
}
2939

3040
T index(...) {
@@ -35,43 +45,42 @@ class Accessor {
3545
if (i == k - 1)
3646
ix += va_arg(args, int);
3747
else
38-
ix += shape[i+1] * va_arg(args, int);
48+
ix += strides[i] * va_arg(args, int);
3949
}
4050
va_end(args);
4151
return data[ix];
4252
}
4353
};
4454

45-
4655
template<unsigned int k, typename T>
4756
class MutAccessor {
48-
int64_t shape[k];
57+
int64_t strides[k-1];
4958
T *data;
5059

5160
public:
52-
MutAccessor(torch::Tensor& tensor) {
61+
MutAccessor(torch::Tensor& tensor) {
5362
data = tensor.data_ptr<T>();
54-
for (int i = 0; i < k; i++) {
55-
shape[i] = tensor.size(i);
63+
for (int i = 1; i < k; i++) {
64+
strides[i-1] = tensor.size(i);
5665
}
66+
reverse_cumprod<k-1>(strides);
5767
}
5868

59-
void set_index(T value,...) {
69+
void set_index(T value, ...) {
6070
va_list args;
61-
va_start(args, k);
71+
va_start(args, value);
6272
int64_t ix = 0;
6373
for (int i = 0; i < k; i++) {
6474
if (i == k - 1)
6575
ix += va_arg(args, int);
6676
else
67-
ix += shape[i+1] * va_arg(args, int);
77+
ix += strides[i] * va_arg(args, int);
6878
}
6979
va_end(args);
7080
data[ix] = value;
7181
}
7282
};
7383

74-
7584
// Inspired from
7685
// https://github.com/flashlight/sequence/blob/main/flashlight/lib/sequence/criterion/cpu/ConnectionistTemporalClassificationCriterion.cpp
7786
template <typename scalar_t, at::ScalarType target_scalar_type>

0 commit comments

Comments
 (0)