@@ -14,17 +14,27 @@ namespace torchaudio {
1414namespace alignment {
1515namespace cpu {
1616
17+ // Compute strides for row-major indexing
18+ template <unsigned int k>
19+ void reverse_cumprod (int64_t (&strides)[k]) {
20+ // Convert dimensions to strides: stride[i] = product of dimensions [i+1..k-1]
21+ for (int i = k - 2 ; i >= 0 ; i--) {
22+ strides[i] = strides[i] * strides[i + 1 ];
23+ }
24+ }
25+
1726template <unsigned int k, typename T>
1827class Accessor {
19- int64_t shape[k ];
28+ int64_t strides[k- 1 ];
2029 T *data;
2130
2231public:
2332 Accessor (const torch::Tensor& tensor) {
2433 data = tensor.data_ptr <T>();
25- for (int i = 0 ; i < k; i++) {
26- shape[i ] = tensor.size (i);
34+ for (int i = 1 ; i < k; i++) {
35+ strides[i- 1 ] = tensor.size (i);
2736 }
37+ reverse_cumprod<k-1 >(strides);
2838 }
2939
3040 T index (...) {
@@ -35,43 +45,42 @@ class Accessor {
3545 if (i == k - 1 )
3646 ix += va_arg (args, int );
3747 else
38- ix += shape[i+ 1 ] * va_arg (args, int );
48+ ix += strides[i ] * va_arg (args, int );
3949 }
4050 va_end (args);
4151 return data[ix];
4252 }
4353};
4454
45-
4655template <unsigned int k, typename T>
4756class MutAccessor {
48- int64_t shape[k ];
57+ int64_t strides[k- 1 ];
4958 T *data;
5059
5160public:
52- MutAccessor (torch::Tensor& tensor) {
61+ MutAccessor (torch::Tensor& tensor) {
5362 data = tensor.data_ptr <T>();
54- for (int i = 0 ; i < k; i++) {
55- shape[i ] = tensor.size (i);
63+ for (int i = 1 ; i < k; i++) {
64+ strides[i- 1 ] = tensor.size (i);
5665 }
66+ reverse_cumprod<k-1 >(strides);
5767 }
5868
59- void set_index (T value,...) {
69+ void set_index (T value, ...) {
6070 va_list args;
61- va_start (args, k );
71+ va_start (args, value );
6272 int64_t ix = 0 ;
6373 for (int i = 0 ; i < k; i++) {
6474 if (i == k - 1 )
6575 ix += va_arg (args, int );
6676 else
67- ix += shape[i+ 1 ] * va_arg (args, int );
77+ ix += strides[i ] * va_arg (args, int );
6878 }
6979 va_end (args);
7080 data[ix] = value;
7181 }
7282};
7383
74-
7584// Inspired from
7685// https://github.com/flashlight/sequence/blob/main/flashlight/lib/sequence/criterion/cpu/ConnectionistTemporalClassificationCriterion.cpp
7786template <typename scalar_t , at::ScalarType target_scalar_type>
0 commit comments