@@ -27,10 +27,9 @@ void forced_align_impl(
27
27
const auto L = targets.size (1 );
28
28
const auto S = 2 * L + 1 ;
29
29
30
- auto alphas_a = new scalar_t [S][2 ]; // scalar_t is just logProbs.dtype()
31
- for (int i = 0 ; i < S; i++) {
32
- alphas_a[i][0 ] = kNegInfinity ;
33
- alphas_a[i][1 ] = kNegInfinity ;
30
+ auto alphas_a = new scalar_t [2 * S]; // scalar_t is just logProbs.dtype()
31
+ for (int i = 0 ; i < 2 * S; i++) {
32
+ alphas_a[i] = kNegInfinity ;
34
33
}
35
34
36
35
auto backPtr_a = new int8_t [T * S];
@@ -59,7 +58,7 @@ void forced_align_impl(
59
58
auto end = (S == 1 ) ? 1 : 2 ;
60
59
for (auto i = start; i < end; i++) {
61
60
auto labelIdx = (i % 2 == 0 ) ? blank : targets_a[batchIndex][i / 2 ];
62
- alphas_a[i][ 0 ] = logProbs_a[batchIndex][0 ][labelIdx];
61
+ alphas_a[i] = logProbs_a[batchIndex][0 ][labelIdx]; // alphas_a[0, i]
63
62
}
64
63
for (auto t = 1 ; t < T; t++) {
65
64
if (T - t <= L + R) {
@@ -82,18 +81,18 @@ void forced_align_impl(
82
81
auto curIdxOffset = t % 2 ;
83
82
auto prevIdxOffset = (t - 1 ) % 2 ;
84
83
for (auto j = 0 ; j < S; ++j) {
85
- alphas_a[j][curIdxOffset] = -std::numeric_limits<scalar_t >::infinity ();
84
+ alphas_a[curIdxOffset * S + j] = -std::numeric_limits<scalar_t >::infinity (); // alphas_a[curIdxOffset][j]
86
85
}
87
86
if (start == 0 ) {
88
- alphas_a[0 ][ curIdxOffset] =
89
- alphas_a[0 ][ prevIdxOffset] + logProbs_a[batchIndex][t][blank];
87
+ alphas_a[curIdxOffset * S ] =
88
+ alphas_a[prevIdxOffset * S ] + logProbs_a[batchIndex][t][blank]; // alphas_a[curIdxOffset][0]
90
89
backPtr_a[S * t] = 0 ; // backPtr_a[t][0] = 0
91
90
startloop += 1 ;
92
91
}
93
92
94
93
for (auto i = startloop; i < end; i++) {
95
- auto x0 = alphas_a[i] [prevIdxOffset];
96
- auto x1 = alphas_a[i - 1 ][prevIdxOffset];
94
+ auto x0 = alphas_a[prevIdxOffset * S + i]; // alphas_a [prevIdxOffset][i ];
95
+ auto x1 = alphas_a[prevIdxOffset * S + i - 1 ]; // alphas_a [prevIdxOffset][i - 1 ];
97
96
auto x2 = -std::numeric_limits<scalar_t >::infinity ();
98
97
99
98
auto labelIdx = (i % 2 == 0 ) ? blank : targets_a[batchIndex][i / 2 ];
@@ -104,7 +103,7 @@ void forced_align_impl(
104
103
// (i != 1) just ensures we don't access targets[i - 2] if its i < 2
105
104
if (i % 2 != 0 && i != 1 &&
106
105
targets_a[batchIndex][i / 2 ] != targets_a[batchIndex][i / 2 - 1 ]) {
107
- x2 = alphas_a[i - 2 ][prevIdxOffset];
106
+ x2 = alphas_a[prevIdxOffset * S + i - 2 ]; // alphas_a [prevIdxOffset][i - 2 ];
108
107
}
109
108
scalar_t result = 0.0 ;
110
109
if (x2 > x1 && x2 > x0) {
@@ -117,11 +116,12 @@ void forced_align_impl(
117
116
result = x0;
118
117
backPtr_a[t * S + i] = 0 ; // backPtr_a[t][i] = 0
119
118
}
120
- alphas_a[i][curIdxOffset] = result + logProbs_a[batchIndex][t][labelIdx];
119
+ alphas_a[curIdxOffset * S + i] = result + logProbs_a[batchIndex][t][labelIdx]; // alphas_a[curIdxOffset][i]
121
120
}
122
121
}
123
122
auto idx1 = (T - 1 ) % 2 ;
124
- auto ltrIdx = alphas_a[S - 1 ][idx1] > alphas_a[S - 2 ][idx1] ? S - 1 : S - 2 ;
123
+ auto ltrIdx = alphas_a[S * idx1 + S - 1 ] >
124
+ alphas_a[S * idx1 + S - 2 ] ? S - 1 : S - 2 ; // alphas_a[idx1][S - 1], alphas_a[idx1][S - 2]
125
125
delete[] alphas_a;
126
126
// path stores the token index for each time step after force alignment.
127
127
for (auto t = T - 1 ; t > -1 ; t--) {
0 commit comments