Skip to content

Commit b7937d2

Browse files
authored
[cherry pick] Fix the integer overflow problem of sequence2batch. (#22479)
cherry-pick from the branch develop,fix the overflow of sequence2batch
1 parent 61ec75c commit b7937d2

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

paddle/fluid/operators/math/sequence2batch.h

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,11 @@ class LoDTensor2BatchFunctor {
5050
// seq_info[3] = {(4, 5, 1), (0, 4, 0), (9, 3, 2)}
5151
//
5252
struct SeqInfo {
53-
SeqInfo(int start, int length, int seq_idx)
53+
SeqInfo(size_t start, size_t length, size_t seq_idx)
5454
: start(start), length(length), seq_idx(seq_idx) {}
55-
int start;
56-
int length;
57-
int seq_idx;
55+
size_t start;
56+
size_t length;
57+
size_t seq_idx;
5858
};
5959

6060
public:
@@ -82,7 +82,7 @@ class LoDTensor2BatchFunctor {
8282

8383
std::vector<SeqInfo> seq_info;
8484
for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) {
85-
int length = lod[seq_id + 1] - lod[seq_id];
85+
size_t length = lod[seq_id + 1] - lod[seq_id];
8686
seq_info.emplace_back(lod[seq_id], length, seq_id);
8787
}
8888

@@ -118,8 +118,8 @@ class LoDTensor2BatchFunctor {
118118
batch_lods.emplace_back(std::vector<size_t>{0});
119119

120120
// batch_lods[0] is the start positions for batch LoDTensor
121-
int max_seqlen = seq_info[0].length;
122-
batch_lods[0].resize(static_cast<size_t>(max_seqlen + 1));
121+
size_t max_seqlen = seq_info[0].length;
122+
batch_lods[0].resize(max_seqlen + 1);
123123
// batch_lods[1] is the raw index in the input LoDTensor
124124
batch_lods[1].resize(static_cast<size_t>(lod_tensor.dims()[0]));
125125
// batch_lods[2] is the sort order for the input LoDTensor.
@@ -128,11 +128,11 @@ class LoDTensor2BatchFunctor {
128128
size_t* batch_starts = batch_lods[0].data();
129129
size_t* seq2batch_idx = batch_lods[1].data();
130130
batch_starts[0] = 0;
131-
for (int n = 0; n < max_seqlen; n++) {
132-
auto batch_id = static_cast<int>(batch_starts[n]);
131+
for (size_t n = 0; n < max_seqlen; n++) {
132+
size_t batch_id = batch_starts[n];
133133
for (size_t i = 0; i < seq_info.size(); ++i) {
134-
int seq_len = seq_info[i].length;
135-
int start = seq_info[i].start;
134+
size_t seq_len = seq_info[i].length;
135+
size_t start = seq_info[i].start;
136136
if (n < seq_len) {
137137
seq2batch_idx[batch_id] =
138138
is_reverse ? start + seq_len - 1 - n : start + n;
@@ -141,7 +141,7 @@ class LoDTensor2BatchFunctor {
141141
break;
142142
}
143143
}
144-
batch_starts[n + 1] = static_cast<size_t>(batch_id);
144+
batch_starts[n + 1] = batch_id;
145145
}
146146
size_t* seq_order = batch_lods[2].data();
147147
for (size_t i = 0; i < seq_info.size(); ++i) {

0 commit comments

Comments
 (0)