4
4
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>

#include <cstdarg>
#include <cstdint>
9
+
7
10
8
11
using namespace std ;
9
12
10
13
namespace torchaudio {
11
14
namespace alignment {
12
15
namespace cpu {
16
+
17
+ template <unsigned int k, typename T>
18
+ class Accessor {
19
+ int64_t shape[k];
20
+ T *data;
21
+
22
+ public:
23
+ Accessor (const torch::Tensor& tensor) {
24
+ data = tensor.data_ptr <T>();
25
+ for (int i = 0 ; i < k; i++) {
26
+ shape[i] = tensor.size (i);
27
+ }
28
+ }
29
+
30
+ T index (...) {
31
+ va_list args;
32
+ va_start (args, k);
33
+ int64_t ix = 0 ;
34
+ for (int i = 0 ; i < k; i++) {
35
+ if (i == k - 1 )
36
+ ix += va_arg (args, int );
37
+ else
38
+ ix += shape[i+1 ] * va_arg (args, int );
39
+ }
40
+ va_end (args);
41
+ return data[ix];
42
+ }
43
+
44
+ // void set_index(T val,...) {
45
+ // va_list args;
46
+ // va_start(args, k);
47
+ // int64_t ix = 0;
48
+ // for (int i = 0; i < k; i++) {
49
+ // if (i == k - 1)
50
+ // ix += va_arg(args, int);
51
+ // else
52
+ // ix += shape[i+1] * va_arg(args, int);
53
+ // }
54
+ // va_end(args);
55
+ // data[ix] = val;
56
+ // }
57
+ };
58
+
13
59
// Inspired from
14
60
// https://github.com/flashlight/sequence/blob/main/flashlight/lib/sequence/criterion/cpu/ConnectionistTemporalClassificationCriterion.cpp
15
61
template <typename scalar_t , at::ScalarType target_scalar_type>
@@ -38,12 +84,12 @@ void forced_align_impl(
38
84
backPtr_a[i] = -1 ;
39
85
}
40
86
41
- auto logProbs_a = logProbs. accessor < scalar_t , 3 >( );
42
- auto targets_a = targets. accessor < target_t , 2 >( );
87
+ auto logProbs_a = Accessor< 3 , scalar_t >(logProbs );
88
+ auto targets_a = Accessor< 2 , target_t >(targets );
43
89
auto paths_a = paths.accessor <target_t , 2 >();
44
90
auto R = 0 ;
45
91
for (auto i = 1 ; i < L; i++) {
46
- if (targets_a[ batchIndex][i] == targets_a[ batchIndex][ i - 1 ] ) {
92
+ if (targets_a. index ( batchIndex, i) == targets_a. index ( batchIndex, i - 1 ) ) {
47
93
++R;
48
94
}
49
95
}
@@ -58,22 +104,22 @@ void forced_align_impl(
58
104
auto start = T - (L + R) > 0 ? 0 : 1 ;
59
105
auto end = (S == 1 ) ? 1 : 2 ;
60
106
for (auto i = start; i < end; i++) {
61
- auto labelIdx = (i % 2 == 0 ) ? blank : targets_a[ batchIndex][ i / 2 ] ;
62
- alphas_a[i][0 ] = logProbs_a[ batchIndex][ 0 ][ labelIdx] ;
107
+ auto labelIdx = (i % 2 == 0 ) ? blank : targets_a. index ( batchIndex, i / 2 ) ;
108
+ alphas_a[i][0 ] = logProbs_a. index ( batchIndex, 0 , labelIdx) ;
63
109
}
64
110
for (auto t = 1 ; t < T; t++) {
65
111
if (T - t <= L + R) {
66
112
if ((start % 2 == 1 ) &&
67
- targets_a[ batchIndex][ start / 2 ] !=
68
- targets_a[ batchIndex][ start / 2 + 1 ] ) {
113
+ targets_a. index ( batchIndex, start / 2 ) !=
114
+ targets_a. index ( batchIndex, start / 2 + 1 ) ) {
69
115
start = start + 1 ;
70
116
}
71
117
start = start + 1 ;
72
118
}
73
119
if (t <= L + R) {
74
120
if (end % 2 == 0 && end < 2 * L &&
75
- targets_a[ batchIndex][ end / 2 - 1 ] !=
76
- targets_a[ batchIndex][ end / 2 ] ) {
121
+ targets_a. index ( batchIndex, end / 2 - 1 ) !=
122
+ targets_a. index ( batchIndex, end / 2 ) ) {
77
123
end = end + 1 ;
78
124
}
79
125
end = end + 1 ;
@@ -86,7 +132,7 @@ void forced_align_impl(
86
132
}
87
133
if (start == 0 ) {
88
134
alphas_a[0 ][curIdxOffset] =
89
- alphas_a[0 ][prevIdxOffset] + logProbs_a[ batchIndex][t][ blank] ;
135
+ alphas_a[0 ][prevIdxOffset] + logProbs_a. index ( batchIndex, t, blank) ;
90
136
backPtr_a[S * t] = 0 ;
91
137
startloop += 1 ;
92
138
}
@@ -96,14 +142,14 @@ void forced_align_impl(
96
142
auto x1 = alphas_a[i - 1 ][prevIdxOffset];
97
143
auto x2 = -std::numeric_limits<scalar_t >::infinity ();
98
144
99
- auto labelIdx = (i % 2 == 0 ) ? blank : targets_a[ batchIndex][ i / 2 ] ;
145
+ auto labelIdx = (i % 2 == 0 ) ? blank : targets_a. index ( batchIndex, i / 2 ) ;
100
146
101
147
// In CTC, the optimal path may optionally chose to skip a blank label.
102
148
// x2 represents skipping a letter, and can only happen if we're not
103
149
// currently on a blank_label, and we're not on a repeat letter
104
150
// (i != 1) just ensures we don't access targets[i - 2] if its i < 2
105
151
if (i % 2 != 0 && i != 1 &&
106
- targets_a[ batchIndex][ i / 2 ] != targets_a[ batchIndex][ i / 2 - 1 ] ) {
152
+ targets_a. index ( batchIndex, i / 2 ) != targets_a. index ( batchIndex, i / 2 - 1 ) ) {
107
153
x2 = alphas_a[i - 2 ][prevIdxOffset];
108
154
}
109
155
scalar_t result = 0.0 ;
@@ -117,14 +163,14 @@ void forced_align_impl(
117
163
result = x0;
118
164
backPtr_a[t * S + i] = 0 ;
119
165
}
120
- alphas_a[i][curIdxOffset] = result + logProbs_a[ batchIndex][t][ labelIdx] ;
166
+ alphas_a[i][curIdxOffset] = result + logProbs_a. index ( batchIndex, t, labelIdx) ;
121
167
}
122
168
}
123
169
auto idx1 = (T - 1 ) % 2 ;
124
170
auto ltrIdx = alphas_a[S - 1 ][idx1] > alphas_a[S - 2 ][idx1] ? S - 1 : S - 2 ;
125
171
// path stores the token index for each time step after force alignment.
126
172
for (auto t = T - 1 ; t > -1 ; t--) {
127
- auto lbl_idx = ltrIdx % 2 == 0 ? blank : targets_a[ batchIndex][ ltrIdx / 2 ] ;
173
+ auto lbl_idx = ltrIdx % 2 == 0 ? blank : targets_a. index ( batchIndex, ltrIdx / 2 ) ;
128
174
paths_a[batchIndex][t] = lbl_idx;
129
175
ltrIdx -= backPtr_a[t * S + ltrIdx];
130
176
}
0 commit comments