@@ -33,11 +33,14 @@ void forced_align_impl(
     alphas_a[i][1] = kNegInfinity;
   }

-  torch::Tensor backPtr = torch::empty({T, S}, torch::kInt8).fill_(-1);
+  auto backPtr_a = new int8_t[T * S];
+  for (int i = 0; i < T * S; i++) {
+    backPtr_a[i] = -1;
+  }
+
   auto logProbs_a = logProbs.accessor<scalar_t, 3>();
   auto targets_a = targets.accessor<target_t, 2>();
   auto paths_a = paths.accessor<target_t, 2>();
-  auto backPtr_a = backPtr.accessor<int8_t, 2>();
   auto R = 0;
   for (auto i = 1; i < L; i++) {
     if (targets_a[batchIndex][i] == targets_a[batchIndex][i - 1]) {
@@ -84,7 +87,7 @@ void forced_align_impl(
     if (start == 0) {
       alphas_a[0][curIdxOffset] =
           alphas_a[0][prevIdxOffset] + logProbs_a[batchIndex][t][blank];
-      backPtr_a[t][0] = 0;
+      backPtr_a[S * t] = 0;
       startloop += 1;
     }

@@ -106,13 +109,13 @@ void forced_align_impl(
       scalar_t result = 0.0;
       if (x2 > x1 && x2 > x0) {
         result = x2;
-        backPtr_a[t][i] = 2;
+        backPtr_a[t * S + i] = 2;
       } else if (x1 > x0 && x1 > x2) {
         result = x1;
-        backPtr_a[t][i] = 1;
+        backPtr_a[t * S + i] = 1;
       } else {
         result = x0;
-        backPtr_a[t][i] = 0;
+        backPtr_a[t * S + i] = 0;
       }
       alphas_a[i][curIdxOffset] = result + logProbs_a[batchIndex][t][labelIdx];
     }
@@ -123,7 +126,7 @@ void forced_align_impl(
   for (auto t = T - 1; t > -1; t--) {
     auto lbl_idx = ltrIdx % 2 == 0 ? blank : targets_a[batchIndex][ltrIdx / 2];
     paths_a[batchIndex][t] = lbl_idx;
-    ltrIdx -= backPtr_a[t][ltrIdx];
+    ltrIdx -= backPtr_a[t * S + ltrIdx];
   }
 }
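The hunks above replace the torch::Tensor back-pointer table (and its accessor<int8_t, 2> view) with a raw int8_t buffer, so the 2-D lookups backPtr_a[t][i] become the row-major offsets backPtr_a[t * S + i], with backPtr_a[S * t] as the same offset for column 0. Below is a minimal sketch of that index mapping; the indexing expressions come from the diff, while T, S, the written values, and the main() harness are illustrative, not part of the commit.

// Sketch only: demonstrates the row-major mapping used by the diff,
// where backPtr_a[t * S + i] plays the role of the old backPtr_a[t][i].
#include <cassert>
#include <cstdint>

int main() {
  const int T = 4; // illustrative number of time steps
  const int S = 3; // illustrative trellis width (2 * L + 1 in the kernel)

  // Flat buffer, initialized to -1 exactly as in the first hunk.
  auto backPtr_a = new int8_t[T * S];
  for (int i = 0; i < T * S; i++) {
    backPtr_a[i] = -1;
  }

  // Writing through the flat index...
  int t = 2, i = 1;
  backPtr_a[t * S + i] = 2;

  // ...touches the element a 2-D view would call row t, column i,
  // and leaves the rest of row t untouched.
  assert(backPtr_a[t * S + i] == 2);
  assert(backPtr_a[S * t] == -1); // column 0 of row t, cf. backPtr_a[S * t] = 0;

  delete[] backPtr_a; // a raw new[] needs a matching delete[]
  return 0;
}

One consequence of this design choice: unlike a torch::Tensor, a buffer from new[] is not reference-counted, and none of the hunks shown here free backPtr_a, so the full commit presumably pairs the allocation with a delete[] after the backtracking loop.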