@@ -50,11 +50,11 @@ class LoDTensor2BatchFunctor {
50
50
// seq_info[3] = {(4, 5, 1), (0, 4, 0), (9, 3, 2)}
51
51
//
52
52
struct SeqInfo {
53
- SeqInfo (int start, int length, int seq_idx)
53
+ SeqInfo (size_t start, size_t length, size_t seq_idx)
54
54
: start(start), length(length), seq_idx(seq_idx) {}
55
- int start;
56
- int length;
57
- int seq_idx;
55
+ size_t start;
56
+ size_t length;
57
+ size_t seq_idx;
58
58
};
59
59
60
60
public:
@@ -82,7 +82,7 @@ class LoDTensor2BatchFunctor {
82
82
83
83
std::vector<SeqInfo> seq_info;
84
84
for (size_t seq_id = 0 ; seq_id < lod.size () - 1 ; ++seq_id) {
85
- int length = lod[seq_id + 1 ] - lod[seq_id];
85
+ size_t length = lod[seq_id + 1 ] - lod[seq_id];
86
86
seq_info.emplace_back (lod[seq_id], length, seq_id);
87
87
}
88
88
@@ -118,8 +118,8 @@ class LoDTensor2BatchFunctor {
118
118
batch_lods.emplace_back (std::vector<size_t >{0 });
119
119
120
120
// batch_lods[0] is the start positions for batch LoDTensor
121
- int max_seqlen = seq_info[0 ].length ;
122
- batch_lods[0 ].resize (static_cast < size_t >( max_seqlen + 1 ) );
121
+ size_t max_seqlen = seq_info[0 ].length ;
122
+ batch_lods[0 ].resize (max_seqlen + 1 );
123
123
// batch_lods[1] is the raw index in the input LoDTensor
124
124
batch_lods[1 ].resize (static_cast <size_t >(lod_tensor.dims ()[0 ]));
125
125
// batch_lods[2] is the sort order for the input LoDTensor.
@@ -128,11 +128,11 @@ class LoDTensor2BatchFunctor {
128
128
size_t * batch_starts = batch_lods[0 ].data ();
129
129
size_t * seq2batch_idx = batch_lods[1 ].data ();
130
130
batch_starts[0 ] = 0 ;
131
- for (int n = 0 ; n < max_seqlen; n++) {
132
- auto batch_id = static_cast < int >( batch_starts[n]) ;
131
+ for (size_t n = 0 ; n < max_seqlen; n++) {
132
+ size_t batch_id = batch_starts[n];
133
133
for (size_t i = 0 ; i < seq_info.size (); ++i) {
134
- int seq_len = seq_info[i].length ;
135
- int start = seq_info[i].start ;
134
+ size_t seq_len = seq_info[i].length ;
135
+ size_t start = seq_info[i].start ;
136
136
if (n < seq_len) {
137
137
seq2batch_idx[batch_id] =
138
138
is_reverse ? start + seq_len - 1 - n : start + n;
@@ -141,7 +141,7 @@ class LoDTensor2BatchFunctor {
141
141
break ;
142
142
}
143
143
}
144
- batch_starts[n + 1 ] = static_cast < size_t >( batch_id) ;
144
+ batch_starts[n + 1 ] = batch_id;
145
145
}
146
146
size_t * seq_order = batch_lods[2 ].data ();
147
147
for (size_t i = 0 ; i < seq_info.size (); ++i) {
0 commit comments