@@ -32,11 +32,14 @@ void forced_align_impl(
32
32
alphas_a[i] = kNegInfinity ;
33
33
}
34
34
35
- torch::Tensor backPtr = torch::empty ({T, S}, torch::kInt8 ).fill_ (-1 );
35
+ auto backPtr_a = new int8_t [T * S];
36
+ for (int i = 0 ; i < T * S; i++) {
37
+ backPtr_a[i] = -1 ;
38
+ }
39
+
36
40
auto logProbs_a = logProbs.accessor <scalar_t , 3 >();
37
41
auto targets_a = targets.accessor <target_t , 2 >();
38
42
auto paths_a = paths.accessor <target_t , 2 >();
39
- auto backPtr_a = backPtr.accessor <int8_t , 2 >();
40
43
auto R = 0 ;
41
44
for (auto i = 1 ; i < L; i++) {
42
45
if (targets_a[batchIndex][i] == targets_a[batchIndex][i - 1 ]) {
@@ -82,8 +85,8 @@ void forced_align_impl(
82
85
}
83
86
if (start == 0 ) {
84
87
alphas_a[curIdxOffset * S] =
85
- alphas_a[prevIdxOffset * S] + logProbs_a[batchIndex][t][blank];
86
- backPtr_a[t][ 0 ] = 0 ;
88
+ alphas_a[prevIdxOffset * S] + logProbs_a[batchIndex][t][blank]; // alphas_a[curIdxOffset][0]
89
+ backPtr_a[S * t] = 0 ; // backPtr_a[t][ 0] = 0
87
90
startloop += 1 ;
88
91
}
89
92
@@ -105,13 +108,13 @@ void forced_align_impl(
105
108
scalar_t result = 0.0 ;
106
109
if (x2 > x1 && x2 > x0) {
107
110
result = x2;
108
- backPtr_a[t][ i] = 2 ;
111
+ backPtr_a[t * S + i] = 2 ; // backPtr_a[t][ i] = 2
109
112
} else if (x1 > x0 && x1 > x2) {
110
113
result = x1;
111
- backPtr_a[t][ i] = 1 ;
114
+ backPtr_a[t * S + i] = 1 ; // backPtr_a[t][ i] = 1
112
115
} else {
113
116
result = x0;
114
- backPtr_a[t][ i] = 0 ;
117
+ backPtr_a[t * S + i] = 0 ; // backPtr_a[t][ i] = 0
115
118
}
116
119
alphas_a[curIdxOffset * S + i] = result + logProbs_a[batchIndex][t][labelIdx]; // alphas_a[curIdxOffset][i]
117
120
}
@@ -124,8 +127,9 @@ void forced_align_impl(
124
127
for (auto t = T - 1 ; t > -1 ; t--) {
125
128
auto lbl_idx = ltrIdx % 2 == 0 ? blank : targets_a[batchIndex][ltrIdx / 2 ];
126
129
paths_a[batchIndex][t] = lbl_idx;
127
- ltrIdx -= backPtr_a[t][ ltrIdx];
130
+ ltrIdx -= backPtr_a[t * S + ltrIdx]; // backPtr_a[t][ ltrIdx]
128
131
}
132
+ delete[] backPtr_a;
129
133
}
130
134
131
135
std::tuple<torch::Tensor, torch::Tensor> compute (
0 commit comments