@@ -14,37 +14,24 @@ namespace torchaudio {
14
14
namespace alignment {
15
15
namespace cpu {
16
16
17
- // Compute strides for row-major indexing
18
- template <unsigned int k>
19
- void reverse_cumprod (int64_t (&strides)[k]) {
20
- // Convert dimensions to strides: stride[i] = product of dimensions [i+1..k-1]
21
- for (int i = k - 2 ; i >= 0 ; i--) {
22
- strides[i] = strides[i] * strides[i + 1 ];
23
- }
24
- }
25
-
26
17
template <unsigned int k, typename T>
27
18
class Accessor {
28
- int64_t strides[k- 1 ];
19
+ int64_t strides[k];
29
20
T *data;
30
21
31
22
public:
32
23
Accessor (const torch::Tensor& tensor) {
33
24
data = tensor.data_ptr <T>();
34
- for (int i = 1 ; i < k; i++) {
35
- strides[i- 1 ] = tensor.size (i);
25
+ for (int i = 0 ; i < k; i++) {
26
+ strides[i] = tensor.stride (i);
36
27
}
37
- reverse_cumprod<k-1 >(strides);
38
28
}
39
29
40
30
T index (...) {
41
31
va_list args;
42
32
va_start (args, k);
43
33
int64_t ix = 0 ;
44
34
for (int i = 0 ; i < k; i++) {
45
- if (i == k - 1 )
46
- ix += va_arg (args, int );
47
- else
48
35
ix += strides[i] * va_arg (args, int );
49
36
}
50
37
va_end (args);
@@ -54,26 +41,22 @@ class Accessor {
54
41
55
42
template <unsigned int k, typename T>
56
43
class MutAccessor {
57
- int64_t strides[k- 1 ];
44
+ int64_t strides[k];
58
45
T *data;
59
46
60
47
public:
61
48
MutAccessor (torch::Tensor& tensor) {
62
49
data = tensor.data_ptr <T>();
63
- for (int i = 1 ; i < k; i++) {
64
- strides[i- 1 ] = tensor.size (i);
50
+ for (int i = 0 ; i < k; i++) {
51
+ strides[i] = tensor.stride (i);
65
52
}
66
- reverse_cumprod<k-1 >(strides);
67
53
}
68
54
69
55
void set_index (T value, ...) {
70
56
va_list args;
71
57
va_start (args, value);
72
58
int64_t ix = 0 ;
73
59
for (int i = 0 ; i < k; i++) {
74
- if (i == k - 1 )
75
- ix += va_arg (args, int );
76
- else
77
60
ix += strides[i] * va_arg (args, int );
78
61
}
79
62
va_end (args);
0 commit comments