66#include < torch/csrc/inductor/aoti_torch/c/shim.h>
77#include < torch/csrc/inductor/aoti_torch/utils.h>
88#include < cstdarg>
9+ #include < type_traits>
910
1011
1112using namespace std ;
@@ -14,14 +15,16 @@ namespace torchaudio {
1415namespace alignment {
1516namespace cpu {
1617
17- template <unsigned int k, typename T>
18+ template <unsigned int k, typename T, bool IsConst = true >
1819class Accessor {
1920 int64_t strides[k];
2021 T *data;
2122
2223public:
23- Accessor (const torch::Tensor& tensor) {
24- data = tensor.data_ptr <T>();
24+ using tensor_type = typename std::conditional<IsConst, const torch::Tensor&, torch::Tensor&>::type;
25+
26+ Accessor (tensor_type tensor) {
27+ data = tensor.template data_ptr <T>();
2528 for (int i = 0 ; i < k; i++) {
2629 strides[i] = tensor.stride (i);
2730 }
@@ -37,22 +40,9 @@ class Accessor {
3740 va_end (args);
3841 return data[ix];
3942 }
40- };
41-
42- template <unsigned int k, typename T>
43- class MutAccessor {
44- int64_t strides[k];
45- T *data;
46-
47- public:
48- MutAccessor (torch::Tensor& tensor) {
49- data = tensor.data_ptr <T>();
50- for (int i = 0 ; i < k; i++) {
51- strides[i] = tensor.stride (i);
52- }
53- }
5443
55- void set_index (T value, ...) {
44+ template <bool C = IsConst>
45+ typename std::enable_if<!C, void >::type set_index (T value, ...) {
5646 va_list args;
5747 va_start (args, value);
5848 int64_t ix = 0 ;
@@ -92,9 +82,9 @@ void forced_align_impl(
9282 backPtr_a[i] = -1 ;
9383 }
9484
95- auto logProbs_a = Accessor<3 , scalar_t >(logProbs);
96- auto targets_a = Accessor<2 , target_t >(targets);
97- auto paths_a = MutAccessor <2 , target_t >(paths);
85+ auto logProbs_a = Accessor<3 , scalar_t , true >(logProbs);
86+ auto targets_a = Accessor<2 , target_t , true >(targets);
87+ auto paths_a = Accessor <2 , target_t , false >(paths);
9888 auto R = 0 ;
9989 for (auto i = 1 ; i < L; i++) {
10090 if (targets_a.index (batchIndex, i) == targets_a.index (batchIndex, i - 1 )) {
0 commit comments