@@ -14,37 +14,24 @@ namespace torchaudio {
1414namespace alignment {
1515namespace cpu {
1616
17- // Compute strides for row-major indexing
18- template <unsigned int k>
19- void reverse_cumprod (int64_t (&strides)[k]) {
20- // Convert dimensions to strides: stride[i] = product of dimensions [i+1..k-1]
21- for (int i = k - 2 ; i >= 0 ; i--) {
22- strides[i] = strides[i] * strides[i + 1 ];
23- }
24- }
25-
2617template <unsigned int k, typename T>
2718class Accessor {
28- int64_t strides[k- 1 ];
19+ int64_t strides[k];
2920 T *data;
3021
3122public:
3223 Accessor (const torch::Tensor& tensor) {
3324 data = tensor.data_ptr <T>();
34- for (int i = 1 ; i < k; i++) {
35- strides[i- 1 ] = tensor.size (i);
25+ for (int i = 0 ; i < k; i++) {
26+ strides[i] = tensor.stride (i);
3627 }
37- reverse_cumprod<k-1 >(strides);
3828 }
3929
4030 T index (...) {
4131 va_list args;
4232 va_start (args, k);
4333 int64_t ix = 0 ;
4434 for (int i = 0 ; i < k; i++) {
45- if (i == k - 1 )
46- ix += va_arg (args, int );
47- else
4835 ix += strides[i] * va_arg (args, int );
4936 }
5037 va_end (args);
@@ -54,26 +41,22 @@ class Accessor {
5441
5542template <unsigned int k, typename T>
5643class MutAccessor {
57- int64_t strides[k- 1 ];
44+ int64_t strides[k];
5845 T *data;
5946
6047public:
6148 MutAccessor (torch::Tensor& tensor) {
6249 data = tensor.data_ptr <T>();
63- for (int i = 1 ; i < k; i++) {
64- strides[i- 1 ] = tensor.size (i);
50+ for (int i = 0 ; i < k; i++) {
51+ strides[i] = tensor.stride (i);
6552 }
66- reverse_cumprod<k-1 >(strides);
6753 }
6854
6955 void set_index (T value, ...) {
7056 va_list args;
7157 va_start (args, value);
7258 int64_t ix = 0 ;
7359 for (int i = 0 ; i < k; i++) {
74- if (i == k - 1 )
75- ix += va_arg (args, int );
76- else
7760 ix += strides[i] * va_arg (args, int );
7861 }
7962 va_end (args);
0 commit comments