 #include <torch/script.h>
 #include <torch/torch.h>
+#include <torch/csrc/stable/library.h>
+#include <torch/csrc/stable/tensor.h>
+#include <torch/csrc/stable/ops.h>
+#include <torch/csrc/inductor/aoti_torch/c/shim.h>

 using namespace std;

@@ -22,17 +26,16 @@ void forced_align_impl(
   const auto T = logProbs.size(1);
   const auto L = targets.size(1);
   const auto S = 2 * L + 1;
-  torch::Tensor alphas = torch::empty(
-      {2, S},
-      torch::TensorOptions()
-          .device(logProbs.device())
-          .dtype(logProbs.dtype()))
-          .fill_(kNegInfinity);
+
+  auto alphas_a = new scalar_t[2 * S]; // scalar_t is just logProbs.dtype()
+  for (int i = 0; i < 2 * S; i++) {
+    alphas_a[i] = kNegInfinity;
+  }
+
   torch::Tensor backPtr = torch::empty({T, S}, torch::kInt8).fill_(-1);
   auto logProbs_a = logProbs.accessor<scalar_t, 3>();
   auto targets_a = targets.accessor<target_t, 2>();
   auto paths_a = paths.accessor<target_t, 2>();
-  auto alphas_a = alphas.accessor<scalar_t, 2>();
   auto backPtr_a = backPtr.accessor<int8_t, 2>();
   auto R = 0;
   for (auto i = 1; i < L; i++) {
@@ -52,7 +55,7 @@ void forced_align_impl(
   auto end = (S == 1) ? 1 : 2;
   for (auto i = start; i < end; i++) {
     auto labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
-    alphas_a[0][i] = logProbs_a[batchIndex][0][labelIdx];
+    alphas_a[i] = logProbs_a[batchIndex][0][labelIdx]; // alphas_a[0][i]
   }
   for (auto t = 1; t < T; t++) {
     if (T - t <= L + R) {
@@ -75,18 +78,18 @@ void forced_align_impl(
     auto curIdxOffset = t % 2;
     auto prevIdxOffset = (t - 1) % 2;
     for (auto j = 0; j < S; ++j) {
-      alphas_a[curIdxOffset][j] = -std::numeric_limits<scalar_t>::infinity();
+      alphas_a[curIdxOffset * S + j] = -std::numeric_limits<scalar_t>::infinity(); // alphas_a[curIdxOffset][j]
     }
     if (start == 0) {
-      alphas_a[curIdxOffset][0] =
-          alphas_a[prevIdxOffset][0] + logProbs_a[batchIndex][t][blank];
+      alphas_a[curIdxOffset * S] =
+          alphas_a[prevIdxOffset * S] + logProbs_a[batchIndex][t][blank];
       backPtr_a[t][0] = 0;
       startloop += 1;
     }

     for (auto i = startloop; i < end; i++) {
-      auto x0 = alphas_a[prevIdxOffset][i];
-      auto x1 = alphas_a[prevIdxOffset][i - 1];
+      auto x0 = alphas_a[prevIdxOffset * S + i]; // alphas_a[prevIdxOffset][i]
+      auto x1 = alphas_a[prevIdxOffset * S + i - 1]; // alphas_a[prevIdxOffset][i - 1]
       auto x2 = -std::numeric_limits<scalar_t>::infinity();

       auto labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
@@ -97,7 +100,7 @@ void forced_align_impl(
       // (i != 1) just ensures we don't access targets[i - 2] if i < 2
       if (i % 2 != 0 && i != 1 &&
           targets_a[batchIndex][i / 2] != targets_a[batchIndex][i / 2 - 1]) {
-        x2 = alphas_a[prevIdxOffset][i - 2];
+        x2 = alphas_a[prevIdxOffset * S + i - 2]; // alphas_a[prevIdxOffset][i - 2]
       }
       scalar_t result = 0.0;
       if (x2 > x1 && x2 > x0) {
@@ -110,11 +113,13 @@ void forced_align_impl(
         result = x0;
         backPtr_a[t][i] = 0;
       }
-      alphas_a[curIdxOffset][i] = result + logProbs_a[batchIndex][t][labelIdx];
+      alphas_a[curIdxOffset * S + i] = result + logProbs_a[batchIndex][t][labelIdx]; // alphas_a[curIdxOffset][i]
     }
   }
   auto idx1 = (T - 1) % 2;
-  auto ltrIdx = alphas_a[idx1][S - 1] > alphas_a[idx1][S - 2] ? S - 1 : S - 2;
+  auto ltrIdx = alphas_a[S * idx1 + S - 1] >
+      alphas_a[S * idx1 + S - 2] ? S - 1 : S - 2; // alphas_a[idx1][S - 1], alphas_a[idx1][S - 2]
+  delete[] alphas_a;
   // path stores the token index for each time step after force alignment.
   for (auto t = T - 1; t > -1; t--) {
     auto lbl_idx = ltrIdx % 2 == 0 ? blank : targets_a[batchIndex][ltrIdx / 2];
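
Note on the indexing change in this diff: the 2 x S alphas table is used as a double buffer over adjacent time steps, and once it is stored as a flat array every 2-D access alphas_a[row][j] becomes the row-major offset alphas_a[row * S + j]. The standalone sketch below shows the same mapping; it uses a hypothetical S, std::vector instead of the raw new[]/delete[] in the patch, and is for illustration only, not the patch's code.

#include <limits>
#include <vector>

int main() {
  const int S = 5; // hypothetical number of CTC states (2 * L + 1)
  const float kNegInfinity = -std::numeric_limits<float>::infinity();

  // Flat buffer standing in for the 2 x S alphas table.
  // Row r, column j lives at index r * S + j (row-major layout).
  std::vector<float> alphas(2 * S, kNegInfinity);

  const int t = 3;                       // some time step
  const int curIdxOffset = t % 2;        // row holding alphas for step t
  const int prevIdxOffset = (t - 1) % 2; // row holding alphas for step t - 1

  // Equivalent of alphas[curIdxOffset][2] = alphas[prevIdxOffset][2] + 1.0f
  // in the old 2-D accessor notation.
  alphas[curIdxOffset * S + 2] = alphas[prevIdxOffset * S + 2] + 1.0f;
  return 0;
}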