Skip to content

Commit 83394ba

Browse files
committed
modified by luotao's suggestion
1 parent 1454cd5 commit 83394ba

File tree

1 file changed

+44
-28
lines changed

1 file changed

+44
-28
lines changed

paddle/fluid/operators/math/sequence_pooling.cc

Lines changed: 44 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,10 @@ class MaxSeqPoolGradFunctor {
104104
};
105105

106106
template <typename T>
107-
class LastFirstSeqPoolFunctor {
107+
class LastSeqPoolFunctor {
108108
public:
109109
void operator()(const platform::CPUDeviceContext& context,
110-
const framework::LoDTensor& input, framework::Tensor* output,
111-
const std::string pooltype) {
110+
const framework::LoDTensor& input, framework::Tensor* output) {
112111
// Create pointers to input and output data
113112
auto* in_data = input.data<T>();
114113
auto* out_data = output->data<T>();
@@ -117,29 +116,40 @@ class LastFirstSeqPoolFunctor {
117116
int64_t item_size = input.numel() / input.dims()[0];
118117
auto lod = input.lod()[0];
119118
int seq_num = static_cast<int>(lod.size()) - 1;
120-
if (pooltype == "LAST") {
121-
for (int i = 0; i < seq_num; ++i) {
122-
// Calculate the length of each sequence
123-
int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
124-
// Point to the begin of next sequence
125-
in_data += seq_len * item_size;
126-
// Copy the last item of sequence to output
127-
std::memcpy(out_data, (in_data - item_size), item_size * sizeof(T));
128-
out_data += item_size;
129-
}
130-
} else if (pooltype == "FIRST") {
131-
for (int i = 0; i < seq_num; ++i) {
132-
// Calculate the length of each sequence
133-
int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
134-
// Copy the first item of sequence to output
135-
std::memcpy(out_data, in_data, item_size * sizeof(T));
136-
// Point to the next sequence
137-
in_data += seq_len * item_size;
138-
out_data += item_size;
119+
for (int i = 0; i < seq_num; ++i) {
120+
// Calculate the length of each sequence
121+
int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
122+
// Point to the begin of next sequence
123+
in_data += seq_len * item_size;
124+
// Copy the last item of sequence to output
125+
std::memcpy(out_data, (in_data - item_size), item_size * sizeof(T));
126+
out_data += item_size;
127+
}
128+
}
129+
};
130+
131+
template <typename T>
132+
class FirstSeqPoolFunctor {
133+
public:
134+
void operator()(const platform::CPUDeviceContext& context,
135+
const framework::LoDTensor& input, framework::Tensor* output) {
136+
// Create pointers to input and output data
137+
auto* in_data = input.data<T>();
138+
auto* out_data = output->data<T>();
139+
140+
// Calculate the size of each item in sequence
141+
int64_t item_size = input.numel() / input.dims()[0];
142+
auto lod = input.lod()[0];
143+
int seq_num = static_cast<int>(lod.size()) - 1;
144+
for (int i = 0; i < seq_num; ++i) {
145+
// Calculate the length of each sequence
146+
int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
147+
// Copy the first item of sequence to output
148+
std::memcpy(out_data, in_data, item_size * sizeof(T));
149+
// Point to the next sequence
150+
in_data += seq_len * item_size;
151+
out_data += item_size;
139152
}
140-
} else {
141-
PADDLE_THROW("it's not LAST or FIRST pool type");
142-
}
143153
}
144154
};
145155

@@ -156,11 +166,17 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
156166
max_pool(context, input, output, index);
157167
return;
158168
}
159-
if (pooltype == "LAST" || pooltype == "FIRST") {
160-
math::LastFirstSeqPoolFunctor<T> lastfirst_pool;
161-
lastfirst_pool(context, input, output, pooltype);
169+
if (pooltype == "LAST") {
170+
math::LastSeqPoolFunctor<T> last_pool;
171+
last_pool(context, input, output);
162172
return;
163173
}
174+
if (pooltype == "FIRST") {
175+
math::FirstSeqPoolFunctor<T> first_pool;
176+
first_pool(context, input, output);
177+
return;
178+
}
179+
164180

165181
auto lod = input.lod()[0];
166182
auto& place = *context.eigen_device();

0 commit comments

Comments
 (0)