Skip to content

Commit 1454cd5

Browse files
committed
pre-commit check
1 parent 7429067 commit 1454cd5

File tree

1 file changed

+31
-33
lines changed

1 file changed

+31
-33
lines changed

paddle/fluid/operators/math/sequence_pooling.cc

Lines changed: 31 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -109,40 +109,38 @@ class LastFirstSeqPoolFunctor {
109109
void operator()(const platform::CPUDeviceContext& context,
110110
const framework::LoDTensor& input, framework::Tensor* output,
111111
const std::string pooltype) {
112-
//Create pointers to input and output data
113-
auto* in_data = input.data<T>();
114-
auto* out_data = output->data<T>();
115-
116-
//Calculate the size of each item in sequence
117-
int64_t item_size = input.numel() / input.dims()[0];
118-
auto lod = input.lod()[0];
119-
int seq_num = static_cast<int>(lod.size()) - 1;
120-
if (pooltype == "LAST"){
121-
for (int i=0; i < seq_num; ++i ){
122-
//Calculate the length of each sequence
123-
int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
124-
//Point to the begin of next sequence
125-
in_data += seq_len* item_size;
126-
//Copy the last item to output
127-
std::memcpy(out_data,(in_data-item_size),item_size*sizeof(T));
128-
out_data += item_size;
129-
}
130-
}
131-
else if(pooltype == "FIRST"){
132-
for (int i=0; i < seq_num; ++i ){
133-
//Calculate the length of each sequence
134-
int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
135-
//Copy the first item of sequence to output
136-
std::memcpy(out_data,in_data,item_size*sizeof(T));
137-
//Point to the next sequence
138-
in_data += seq_len * item_size;
139-
out_data += item_size;
140-
}
141-
}
142-
else {
143-
PADDLE_THROW("it's not LAST or FIRST pool type");
112+
// Create pointers to input and output data
113+
auto* in_data = input.data<T>();
114+
auto* out_data = output->data<T>();
115+
116+
// Calculate the size of each item in sequence
117+
int64_t item_size = input.numel() / input.dims()[0];
118+
auto lod = input.lod()[0];
119+
int seq_num = static_cast<int>(lod.size()) - 1;
120+
if (pooltype == "LAST") {
121+
for (int i = 0; i < seq_num; ++i) {
122+
// Calculate the length of each sequence
123+
int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
124+
// Point to the begin of next sequence
125+
in_data += seq_len * item_size;
126+
// Copy the last item of sequence to output
127+
std::memcpy(out_data, (in_data - item_size), item_size * sizeof(T));
128+
out_data += item_size;
144129
}
145-
}
130+
} else if (pooltype == "FIRST") {
131+
for (int i = 0; i < seq_num; ++i) {
132+
// Calculate the length of each sequence
133+
int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
134+
// Copy the first item of sequence to output
135+
std::memcpy(out_data, in_data, item_size * sizeof(T));
136+
// Point to the next sequence
137+
in_data += seq_len * item_size;
138+
out_data += item_size;
139+
}
140+
} else {
141+
PADDLE_THROW("it's not LAST or FIRST pool type");
142+
}
143+
}
146144
};
147145

148146
template <typename T>

0 commit comments

Comments
 (0)