@@ -1,9 +1,9 @@
-#include <torch/script.h>
-#include <torch/torch.h>
+#include <torch/csrc/inductor/aoti_torch/c/shim.h>
 #include <torch/csrc/stable/library.h>
-#include <torch/csrc/stable/tensor.h>
 #include <torch/csrc/stable/ops.h>
-#include <torch/csrc/inductor/aoti_torch/c/shim.h>
+#include <torch/csrc/stable/tensor.h>
+#include <torch/script.h>
+#include <torch/torch.h>

 using namespace std;

@@ -81,18 +81,21 @@ void forced_align_impl(
     auto curIdxOffset = t % 2;
     auto prevIdxOffset = (t - 1) % 2;
     for (auto j = 0; j < S; ++j) {
-      alphas_a[curIdxOffset * S + j] = -std::numeric_limits<scalar_t>::infinity(); // alphas_a[curIdxOffset][j]
+      alphas_a[curIdxOffset * S + j] = -std::numeric_limits<
+          scalar_t>::infinity(); // alphas_a[curIdxOffset][j]
     }
     if (start == 0) {
-      alphas_a[curIdxOffset * S] =
-          alphas_a[prevIdxOffset * S] + logProbs_a[batchIndex][t][blank]; // alphas_a[curIdxOffset][0]
+      alphas_a[curIdxOffset * S] = alphas_a[prevIdxOffset * S] +
+          logProbs_a[batchIndex][t][blank]; // alphas_a[curIdxOffset][0]
       backPtr_a[S * t] = 0; // backPtr_a[t][0] = 0
       startloop += 1;
     }

     for (auto i = startloop; i < end; i++) {
       auto x0 = alphas_a[prevIdxOffset * S + i]; // alphas_a[prevIdxOffset][i];
-      auto x1 = alphas_a[prevIdxOffset * S + i - 1]; // alphas_a[prevIdxOffset][i - 1];
+      auto x1 =
+          alphas_a[prevIdxOffset * S + i - 1]; // alphas_a[prevIdxOffset][i
+                                               // - 1];
       auto x2 = -std::numeric_limits<scalar_t>::infinity();

       auto labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
@@ -103,7 +106,8 @@ void forced_align_impl(
       // (i != 1) just ensures we don't access targets[i - 2] if its i < 2
       if (i % 2 != 0 && i != 1 &&
           targets_a[batchIndex][i / 2] != targets_a[batchIndex][i / 2 - 1]) {
-        x2 = alphas_a[prevIdxOffset * S + i - 2]; // alphas_a[prevIdxOffset][i - 2];
+        x2 = alphas_a[prevIdxOffset * S + i - 2]; // alphas_a[prevIdxOffset][i -
+                                                  // 2];
       }
       scalar_t result = 0.0;
       if (x2 > x1 && x2 > x0) {
@@ -116,12 +120,14 @@ void forced_align_impl(
         result = x0;
         backPtr_a[t * S + i] = 0; // backPtr_a[t][i] = 0
       }
-      alphas_a[curIdxOffset * S + i] = result + logProbs_a[batchIndex][t][labelIdx]; // alphas_a[curIdxOffset][i]
+      alphas_a[curIdxOffset * S + i] = result +
+          logProbs_a[batchIndex][t][labelIdx]; // alphas_a[curIdxOffset][i]
     }
   }
   auto idx1 = (T - 1) % 2;
-  auto ltrIdx = alphas_a[S * idx1 + S - 1] >
-      alphas_a[S * idx1 + S - 2] ? S - 1 : S - 2; // alphas_a[idx1][S - 1], alphas_a[idx1][S - 2]
+  auto ltrIdx = alphas_a[S * idx1 + S - 1] > alphas_a[S * idx1 + S - 2]
+      ? S - 1
+      : S - 2; // alphas_a[idx1][S - 1], alphas_a[idx1][S - 2]
   delete[] alphas_a;
   // path stores the token index for each time step after force alignment.
   for (auto t = T - 1; t > -1; t--) {
@@ -194,15 +200,9 @@ std::tuple<torch::Tensor, torch::Tensor> compute(
             logProbs, targets, blank, paths);
       }
     });
-  return std::make_tuple(
-      paths,
-      logProbs
-  );
+  return std::make_tuple(paths, logProbs);
 }

-
-
-
 TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
   m.impl("forced_align", &compute);
 }
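
Context for reviewers: the hunks above appear to be formatting-only changes (include sorting plus line re-wrapping); the forced-alignment dynamic program itself is untouched. As a reference for what that loop computes, here is a minimal, self-contained sketch of the same Viterbi-style recurrence over the blank-extended target: S = 2 * L + 1 states, a rolling 2 x S alpha buffer, and a T x S backpointer table. The function name forced_align_sketch, the std::vector-based signature, the label() helper, and the frame-0 initialization are illustrative assumptions, not torchaudio's actual API; the real kernel also restricts each frame to a reachable [start, end) window and reads torch::Tensor accessors instead of vectors.

// Minimal sketch (not the torchaudio source) of the recurrence the kernel
// above implements, assuming plain std::vector inputs instead of tensors.
#include <algorithm>
#include <cstdint>
#include <limits>
#include <vector>

// logProbs: T x C log-probabilities; targets: L token ids; blank: blank id.
// Returns one emitted token id (blank or target token) per frame.
std::vector<int64_t> forced_align_sketch(
    const std::vector<std::vector<float>>& logProbs,
    const std::vector<int64_t>& targets,
    int64_t blank) {
  const int64_t T = static_cast<int64_t>(logProbs.size());
  const int64_t L = static_cast<int64_t>(targets.size());
  if (T == 0) {
    return {};
  }
  const int64_t S = 2 * L + 1; // blank, tok0, blank, tok1, ..., blank
  const float kNegInf = -std::numeric_limits<float>::infinity();
  auto label = [&](int64_t i) { return i % 2 == 0 ? blank : targets[i / 2]; };

  std::vector<float> alphas(2 * S, kNegInf); // rolling 2 x S buffer
  std::vector<int64_t> backPtr(T * S, 0);    // how far each state advanced

  // Frame 0: the path can only start on the leading blank or the first token.
  alphas[0] = logProbs[0][blank];
  if (S > 1) {
    alphas[1] = logProbs[0][label(1)];
  }

  for (int64_t t = 1; t < T; ++t) {
    const int64_t cur = t % 2;
    const int64_t prev = (t - 1) % 2;
    std::fill(
        alphas.begin() + cur * S, alphas.begin() + (cur + 1) * S, kNegInf);
    for (int64_t i = 0; i < S; ++i) {
      // Candidate predecessors: stay (x0), advance one state (x1), or skip
      // the blank between two *different* tokens (x2).
      const float x0 = alphas[prev * S + i];
      const float x1 = i > 0 ? alphas[prev * S + i - 1] : kNegInf;
      float x2 = kNegInf;
      if (i % 2 != 0 && i > 1 && targets[i / 2] != targets[i / 2 - 1]) {
        x2 = alphas[prev * S + i - 2];
      }
      float best = x0;
      int64_t step = 0;
      if (x1 > best) { best = x1; step = 1; }
      if (x2 > best) { best = x2; step = 2; }
      if (best == kNegInf) {
        continue; // state not reachable at this frame
      }
      alphas[cur * S + i] = best + logProbs[t][label(i)];
      backPtr[t * S + i] = step;
    }
  }

  // Backtrack from the better of the two admissible end states (final blank
  // or final token), mirroring the ltrIdx selection in the diff above.
  const int64_t last = (T - 1) % 2;
  int64_t i = (S < 2 || alphas[last * S + S - 1] > alphas[last * S + S - 2])
      ? S - 1
      : S - 2;
  std::vector<int64_t> path(T);
  for (int64_t t = T - 1; t >= 0; --t) {
    path[t] = label(i);
    i -= backPtr[t * S + i];
  }
  return path;
}

The rolling two-row buffer keeps the forward scores at O(S) memory, while the full T x S backpointer table is what makes the exact path recoverable; this matches the alphas_a / backPtr_a layout the rewrapped lines operate on.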