Skip to content

Commit b47c053

Browse files
committed
Merge Accessor and MutAccessor
1 parent 11d1e21 commit b47c053

File tree

1 file changed

+11
-21
lines changed

1 file changed

+11
-21
lines changed

src/libtorchaudio/forced_align/cpu/compute.cpp

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
77
#include <torch/csrc/inductor/aoti_torch/utils.h>
88
#include <cstdarg>
9+
#include <type_traits>
910

1011

1112
using namespace std;
@@ -14,14 +15,16 @@ namespace torchaudio {
1415
namespace alignment {
1516
namespace cpu {
1617

17-
template<unsigned int k, typename T>
18+
template<unsigned int k, typename T, bool IsConst = true>
1819
class Accessor {
1920
int64_t strides[k];
2021
T *data;
2122

2223
public:
23-
Accessor(const torch::Tensor& tensor) {
24-
data = tensor.data_ptr<T>();
24+
using tensor_type = typename std::conditional<IsConst, const torch::Tensor&, torch::Tensor&>::type;
25+
26+
Accessor(tensor_type tensor) {
27+
data = tensor.template data_ptr<T>();
2528
for (int i = 0; i < k; i++) {
2629
strides[i] = tensor.stride(i);
2730
}
@@ -37,22 +40,9 @@ class Accessor {
3740
va_end(args);
3841
return data[ix];
3942
}
40-
};
41-
42-
template<unsigned int k, typename T>
43-
class MutAccessor {
44-
int64_t strides[k];
45-
T *data;
46-
47-
public:
48-
MutAccessor(torch::Tensor& tensor) {
49-
data = tensor.data_ptr<T>();
50-
for (int i = 0; i < k; i++) {
51-
strides[i] = tensor.stride(i);
52-
}
53-
}
5443

55-
void set_index(T value, ...) {
44+
template<bool C = IsConst>
45+
typename std::enable_if<!C, void>::type set_index(T value, ...) {
5646
va_list args;
5747
va_start(args, value);
5848
int64_t ix = 0;
@@ -92,9 +82,9 @@ void forced_align_impl(
9282
backPtr_a[i] = -1;
9383
}
9484

95-
auto logProbs_a = Accessor<3, scalar_t>(logProbs);
96-
auto targets_a = Accessor<2, target_t>(targets);
97-
auto paths_a = MutAccessor<2, target_t>(paths);
85+
auto logProbs_a = Accessor<3, scalar_t, true>(logProbs);
86+
auto targets_a = Accessor<2, target_t, true>(targets);
87+
auto paths_a = Accessor<2, target_t, false>(paths);
9888
auto R = 0;
9989
for (auto i = 1; i < L; i++) {
10090
if (targets_a.index(batchIndex, i) == targets_a.index(batchIndex, i - 1)) {

0 commit comments

Comments
 (0)