Skip to content

Commit 7a94b04

Browse files
committed
Move Accessor to its own file and add tests
1 parent b47c053 commit 7a94b04

File tree

5 files changed

+78
-39
lines changed

5 files changed

+78
-39
lines changed

src/libtorchaudio/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ set(
66
lfilter.cpp
77
overdrive.cpp
88
utils.cpp
9+
accessor_tests.cpp
910
)
1011

1112
set(

src/libtorchaudio/accessor.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#pragma once
2+
3+
#include <torch/torch.h>
4+
#include <type_traits>
5+
#include <cstdarg>
6+
7+
template<unsigned int k, typename T, bool IsConst = true>
8+
class Accessor {
9+
int64_t strides[k];
10+
T *data;
11+
12+
public:
13+
using tensor_type = typename std::conditional<IsConst, const torch::Tensor&, torch::Tensor&>::type;
14+
15+
Accessor(tensor_type tensor) {
16+
data = tensor.template data_ptr<T>();
17+
for (int i = 0; i < k; i++) {
18+
strides[i] = tensor.stride(i);
19+
}
20+
}
21+
22+
T index(...) {
23+
va_list args;
24+
va_start(args, k);
25+
int64_t ix = 0;
26+
for (int i = 0; i < k; i++) {
27+
ix += strides[i] * va_arg(args, int);
28+
}
29+
va_end(args);
30+
return data[ix];
31+
}
32+
33+
template<bool C = IsConst>
34+
typename std::enable_if<!C, void>::type set_index(T value, ...) {
35+
va_list args;
36+
va_start(args, value);
37+
int64_t ix = 0;
38+
for (int i = 0; i < k; i++) {
39+
ix += strides[i] * va_arg(args, int);
40+
}
41+
va_end(args);
42+
data[ix] = value;
43+
}
44+
};
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#include <libtorchaudio/accessor.h>
2+
#include <cstdint>
3+
#include <torch/torch.h>
4+
5+
using namespace std;
6+
7+
bool test_accessor(const torch::Tensor& tensor) {
8+
int64_t* data_ptr = tensor.template data_ptr<int64_t>();
9+
auto accessor = Accessor<3, int64_t>(tensor);
10+
for (int i = 0; i < tensor.size(0); i++) {
11+
for (int j = 0; j < tensor.size(1); j++) {
12+
for (int k = 0; k < tensor.size(2); k++) {
13+
auto check = *(data_ptr++) == accessor.index(i, j, k);
14+
if (!check) {
15+
return false;
16+
}
17+
}
18+
}
19+
}
20+
return true;
21+
}
22+
23+
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
24+
m.def("torchaudio::_test_accessor", &test_accessor);
25+
}

src/libtorchaudio/forced_align/cpu/compute.cpp

Lines changed: 1 addition & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
#include <torch/csrc/stable/ops.h>
66
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
77
#include <torch/csrc/inductor/aoti_torch/utils.h>
8-
#include <cstdarg>
9-
#include <type_traits>
8+
#include <libtorchaudio/accessor.h>
109

1110

1211
using namespace std;
@@ -15,44 +14,7 @@ namespace torchaudio {
1514
namespace alignment {
1615
namespace cpu {
1716

18-
template<unsigned int k, typename T, bool IsConst = true>
19-
class Accessor {
20-
int64_t strides[k];
21-
T *data;
22-
23-
public:
24-
using tensor_type = typename std::conditional<IsConst, const torch::Tensor&, torch::Tensor&>::type;
25-
26-
Accessor(tensor_type tensor) {
27-
data = tensor.template data_ptr<T>();
28-
for (int i = 0; i < k; i++) {
29-
strides[i] = tensor.stride(i);
30-
}
31-
}
3217

33-
T index(...) {
34-
va_list args;
35-
va_start(args, k);
36-
int64_t ix = 0;
37-
for (int i = 0; i < k; i++) {
38-
ix += strides[i] * va_arg(args, int);
39-
}
40-
va_end(args);
41-
return data[ix];
42-
}
43-
44-
template<bool C = IsConst>
45-
typename std::enable_if<!C, void>::type set_index(T value, ...) {
46-
va_list args;
47-
va_start(args, value);
48-
int64_t ix = 0;
49-
for (int i = 0; i < k; i++) {
50-
ix += strides[i] * va_arg(args, int);
51-
}
52-
va_end(args);
53-
data[ix] = value;
54-
}
55-
};
5618

5719
// Inspired from
5820
// https://github.com/flashlight/sequence/blob/main/flashlight/lib/sequence/criterion/cpu/ConnectionistTemporalClassificationCriterion.cpp
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import torch
2+
from torchaudio._extension import _IS_TORCHAUDIO_EXT_AVAILABLE
3+
4+
if _IS_TORCHAUDIO_EXT_AVAILABLE:
5+
def test_accessor():
6+
tensor = torch.randint(1000, (5,4,3))
7+
assert torch.ops.torchaudio._test_accessor(tensor)

0 commit comments

Comments
 (0)