6
6
#include < torch/csrc/inductor/aoti_torch/c/shim.h>
7
7
#include < torch/csrc/inductor/aoti_torch/utils.h>
8
8
#include < cstdarg>
9
+ #include < type_traits>
9
10
10
11
11
12
using namespace std ;
@@ -14,14 +15,16 @@ namespace torchaudio {
14
15
namespace alignment {
15
16
namespace cpu {
16
17
17
- template <unsigned int k, typename T>
18
+ template <unsigned int k, typename T, bool IsConst = true >
18
19
class Accessor {
19
20
int64_t strides[k];
20
21
T *data;
21
22
22
23
public:
23
- Accessor (const torch::Tensor& tensor) {
24
- data = tensor.data_ptr <T>();
24
+ using tensor_type = typename std::conditional<IsConst, const torch::Tensor&, torch::Tensor&>::type;
25
+
26
+ Accessor (tensor_type tensor) {
27
+ data = tensor.template data_ptr <T>();
25
28
for (int i = 0 ; i < k; i++) {
26
29
strides[i] = tensor.stride (i);
27
30
}
@@ -37,22 +40,9 @@ class Accessor {
37
40
va_end (args);
38
41
return data[ix];
39
42
}
40
- };
41
-
42
- template <unsigned int k, typename T>
43
- class MutAccessor {
44
- int64_t strides[k];
45
- T *data;
46
-
47
- public:
48
- MutAccessor (torch::Tensor& tensor) {
49
- data = tensor.data_ptr <T>();
50
- for (int i = 0 ; i < k; i++) {
51
- strides[i] = tensor.stride (i);
52
- }
53
- }
54
43
55
- void set_index (T value, ...) {
44
+ template <bool C = IsConst>
45
+ typename std::enable_if<!C, void >::type set_index (T value, ...) {
56
46
va_list args;
57
47
va_start (args, value);
58
48
int64_t ix = 0 ;
@@ -92,9 +82,9 @@ void forced_align_impl(
92
82
backPtr_a[i] = -1 ;
93
83
}
94
84
95
- auto logProbs_a = Accessor<3 , scalar_t >(logProbs);
96
- auto targets_a = Accessor<2 , target_t >(targets);
97
- auto paths_a = MutAccessor <2 , target_t >(paths);
85
+ auto logProbs_a = Accessor<3 , scalar_t , true >(logProbs);
86
+ auto targets_a = Accessor<2 , target_t , true >(targets);
87
+ auto paths_a = Accessor <2 , target_t , false >(paths);
98
88
auto R = 0 ;
99
89
for (auto i = 1 ; i < L; i++) {
100
90
if (targets_a.index (batchIndex, i) == targets_a.index (batchIndex, i - 1 )) {
0 commit comments