@@ -14,17 +14,27 @@ namespace torchaudio {
14
14
namespace alignment {
15
15
namespace cpu {
16
16
17
+ // Compute strides for row-major indexing
18
+ template <unsigned int k>
19
+ void reverse_cumprod (int64_t (&strides)[k]) {
20
+ // Convert dimensions to strides: stride[i] = product of dimensions [i+1..k-1]
21
+ for (int i = k - 2 ; i >= 0 ; i--) {
22
+ strides[i] = strides[i] * strides[i + 1 ];
23
+ }
24
+ }
25
+
17
26
template <unsigned int k, typename T>
18
27
class Accessor {
19
- int64_t shape[k ];
28
+ int64_t strides[k- 1 ];
20
29
T *data;
21
30
22
31
public:
23
32
Accessor (const torch::Tensor& tensor) {
24
33
data = tensor.data_ptr <T>();
25
- for (int i = 0 ; i < k; i++) {
26
- shape[i ] = tensor.size (i);
34
+ for (int i = 1 ; i < k; i++) {
35
+ strides[i- 1 ] = tensor.size (i);
27
36
}
37
+ reverse_cumprod<k-1 >(strides);
28
38
}
29
39
30
40
T index (...) {
@@ -35,43 +45,42 @@ class Accessor {
35
45
if (i == k - 1 )
36
46
ix += va_arg (args, int );
37
47
else
38
- ix += shape[i+ 1 ] * va_arg (args, int );
48
+ ix += strides[i ] * va_arg (args, int );
39
49
}
40
50
va_end (args);
41
51
return data[ix];
42
52
}
43
53
};
44
54
45
-
46
55
template <unsigned int k, typename T>
47
56
class MutAccessor {
48
- int64_t shape[k ];
57
+ int64_t strides[k- 1 ];
49
58
T *data;
50
59
51
60
public:
52
- MutAccessor (torch::Tensor& tensor) {
61
+ MutAccessor (torch::Tensor& tensor) {
53
62
data = tensor.data_ptr <T>();
54
- for (int i = 0 ; i < k; i++) {
55
- shape[i ] = tensor.size (i);
63
+ for (int i = 1 ; i < k; i++) {
64
+ strides[i- 1 ] = tensor.size (i);
56
65
}
66
+ reverse_cumprod<k-1 >(strides);
57
67
}
58
68
59
- void set_index (T value,...) {
69
+ void set_index (T value, ...) {
60
70
va_list args;
61
- va_start (args, k );
71
+ va_start (args, value );
62
72
int64_t ix = 0 ;
63
73
for (int i = 0 ; i < k; i++) {
64
74
if (i == k - 1 )
65
75
ix += va_arg (args, int );
66
76
else
67
- ix += shape[i+ 1 ] * va_arg (args, int );
77
+ ix += strides[i ] * va_arg (args, int );
68
78
}
69
79
va_end (args);
70
80
data[ix] = value;
71
81
}
72
82
};
73
83
74
-
75
84
// Inspired from
76
85
// https://github.com/flashlight/sequence/blob/main/flashlight/lib/sequence/criterion/cpu/ConnectionistTemporalClassificationCriterion.cpp
77
86
template <typename scalar_t , at::ScalarType target_scalar_type>
0 commit comments