@@ -40,22 +40,38 @@ class Accessor {
40
40
va_end (args);
41
41
return data[ix];
42
42
}
43
+ };
44
+
45
+
46
+ template <unsigned int k, typename T>
47
+ class MutAccessor {
48
+ int64_t shape[k];
49
+ T *data;
50
+
51
+ public:
52
+ MutAccessor (torch::Tensor& tensor) {
53
+ data = tensor.data_ptr <T>();
54
+ for (int i = 0 ; i < k; i++) {
55
+ shape[i] = tensor.size (i);
56
+ }
57
+ }
43
58
44
- // void set_index(T val ,...) {
45
- // va_list args;
46
- // va_start(args, k);
47
- // int64_t ix = 0;
48
- // for (int i = 0; i < k; i++) {
49
- // if (i == k - 1)
50
- // ix += va_arg(args, int);
51
- // else
52
- // ix += shape[i+1] * va_arg(args, int);
53
- // }
54
- // va_end(args);
55
- // data[ix] = val ;
56
- // }
59
+ void set_index (T value ,...) {
60
+ va_list args;
61
+ va_start (args, k);
62
+ int64_t ix = 0 ;
63
+ for (int i = 0 ; i < k; i++) {
64
+ if (i == k - 1 )
65
+ ix += va_arg (args, int );
66
+ else
67
+ ix += shape[i+1 ] * va_arg (args, int );
68
+ }
69
+ va_end (args);
70
+ data[ix] = value ;
71
+ }
57
72
};
58
73
74
+
59
75
// Inspired from
60
76
// https://github.com/flashlight/sequence/blob/main/flashlight/lib/sequence/criterion/cpu/ConnectionistTemporalClassificationCriterion.cpp
61
77
template <typename scalar_t , at::ScalarType target_scalar_type>
@@ -86,7 +102,7 @@ void forced_align_impl(
86
102
87
103
auto logProbs_a = Accessor<3 , scalar_t >(logProbs);
88
104
auto targets_a = Accessor<2 , target_t >(targets);
89
- auto paths_a = paths. accessor < target_t , 2 >( );
105
+ auto paths_a = MutAccessor< 2 , target_t >(paths );
90
106
auto R = 0 ;
91
107
for (auto i = 1 ; i < L; i++) {
92
108
if (targets_a.index (batchIndex, i) == targets_a.index (batchIndex, i - 1 )) {
@@ -171,7 +187,7 @@ void forced_align_impl(
171
187
// path stores the token index for each time step after force alignment.
172
188
for (auto t = T - 1 ; t > -1 ; t--) {
173
189
auto lbl_idx = ltrIdx % 2 == 0 ? blank : targets_a.index (batchIndex, ltrIdx / 2 );
174
- paths_a[ batchIndex][t] = lbl_idx ;
190
+ paths_a. set_index (lbl_idx, batchIndex, t) ;
175
191
ltrIdx -= backPtr_a[t * S + ltrIdx];
176
192
}
177
193
}
0 commit comments