Skip to content

Commit 11d1e21

Browse files
committed
Use strides rather than computing standard strides from dims
1 parent 9beb34a commit 11d1e21

File tree

1 file changed

+6
-23
lines changed

1 file changed

+6
-23
lines changed

src/libtorchaudio/forced_align/cpu/compute.cpp

Lines changed: 6 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,37 +14,24 @@ namespace torchaudio {
1414
namespace alignment {
1515
namespace cpu {
1616

17-
// Compute strides for row-major indexing
18-
template<unsigned int k>
19-
void reverse_cumprod(int64_t (&strides)[k]) {
20-
// Convert dimensions to strides: stride[i] = product of dimensions [i+1..k-1]
21-
for (int i = k - 2; i >= 0; i--) {
22-
strides[i] = strides[i] * strides[i + 1];
23-
}
24-
}
25-
2617
template<unsigned int k, typename T>
2718
class Accessor {
28-
int64_t strides[k-1];
19+
int64_t strides[k];
2920
T *data;
3021

3122
public:
3223
Accessor(const torch::Tensor& tensor) {
3324
data = tensor.data_ptr<T>();
34-
for (int i = 1; i < k; i++) {
35-
strides[i-1] = tensor.size(i);
25+
for (int i = 0; i < k; i++) {
26+
strides[i] = tensor.stride(i);
3627
}
37-
reverse_cumprod<k-1>(strides);
3828
}
3929

4030
T index(...) {
4131
va_list args;
4232
va_start(args, k);
4333
int64_t ix = 0;
4434
for (int i = 0; i < k; i++) {
45-
if (i == k - 1)
46-
ix += va_arg(args, int);
47-
else
4835
ix += strides[i] * va_arg(args, int);
4936
}
5037
va_end(args);
@@ -54,26 +41,22 @@ class Accessor {
5441

5542
template<unsigned int k, typename T>
5643
class MutAccessor {
57-
int64_t strides[k-1];
44+
int64_t strides[k];
5845
T *data;
5946

6047
public:
6148
MutAccessor(torch::Tensor& tensor) {
6249
data = tensor.data_ptr<T>();
63-
for (int i = 1; i < k; i++) {
64-
strides[i-1] = tensor.size(i);
50+
for (int i = 0; i < k; i++) {
51+
strides[i] = tensor.stride(i);
6552
}
66-
reverse_cumprod<k-1>(strides);
6753
}
6854

6955
void set_index(T value, ...) {
7056
va_list args;
7157
va_start(args, value);
7258
int64_t ix = 0;
7359
for (int i = 0; i < k; i++) {
74-
if (i == k - 1)
75-
ix += va_arg(args, int);
76-
else
7760
ix += strides[i] * va_arg(args, int);
7861
}
7962
va_end(args);

0 commit comments

Comments
 (0)