Skip to content

Commit 40c54db

Browse files
authored
Merge pull request #13338 from bingyanghuang/bingyang/seq_pool_memcpy
Use memcpy to rewrite the sequence pooling LAST and FIRST mode
2 parents bdbf1bc + 76553c5 commit 40c54db

File tree

1 file changed

+62
-4
lines changed

1 file changed

+62
-4
lines changed

paddle/fluid/operators/math/sequence_pooling.cc

Lines changed: 62 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,58 @@ class MaxSeqPoolGradFunctor {
103103
}
104104
};
105105

106+
template <typename T>
107+
class LastSeqPoolFunctor {
108+
public:
109+
void operator()(const platform::CPUDeviceContext& context,
110+
const framework::LoDTensor& input,
111+
framework::Tensor* output) {
112+
// Create pointers to input and output data
113+
auto* in_data = input.data<T>();
114+
auto* out_data = output->data<T>();
115+
116+
// Calculate the size of each item in sequence
117+
int64_t item_size = input.numel() / input.dims()[0];
118+
auto lod = input.lod()[0];
119+
int seq_num = static_cast<int>(lod.size()) - 1;
120+
for (int i = 0; i < seq_num; ++i) {
121+
// Calculate the length of each sequence
122+
int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
123+
// Point to the begin of next sequence
124+
in_data += seq_len * item_size;
125+
// Copy the last item of sequence to output
126+
std::memcpy(out_data, (in_data - item_size), item_size * sizeof(T));
127+
out_data += item_size;
128+
}
129+
}
130+
};
131+
132+
template <typename T>
133+
class FirstSeqPoolFunctor {
134+
public:
135+
void operator()(const platform::CPUDeviceContext& context,
136+
const framework::LoDTensor& input,
137+
framework::Tensor* output) {
138+
// Create pointers to input and output data
139+
auto* in_data = input.data<T>();
140+
auto* out_data = output->data<T>();
141+
142+
// Calculate the size of each item in sequence
143+
int64_t item_size = input.numel() / input.dims()[0];
144+
auto lod = input.lod()[0];
145+
int seq_num = static_cast<int>(lod.size()) - 1;
146+
for (int i = 0; i < seq_num; ++i) {
147+
// Calculate the length of each sequence
148+
int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
149+
// Copy the first item of sequence to output
150+
std::memcpy(out_data, in_data, item_size * sizeof(T));
151+
// Point to the next sequence
152+
in_data += seq_len * item_size;
153+
out_data += item_size;
154+
}
155+
}
156+
};
157+
106158
template <typename T>
107159
class SequencePoolFunctor<platform::CPUDeviceContext, T> {
108160
public:
@@ -116,6 +168,16 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
116168
max_pool(context, input, output, index);
117169
return;
118170
}
171+
if (pooltype == "LAST") {
172+
math::LastSeqPoolFunctor<T> last_pool;
173+
last_pool(context, input, output);
174+
return;
175+
}
176+
if (pooltype == "FIRST") {
177+
math::FirstSeqPoolFunctor<T> first_pool;
178+
first_pool(context, input, output);
179+
return;
180+
}
119181
auto lod = input.lod()[0];
120182
auto& place = *context.eigen_device();
121183
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
@@ -133,10 +195,6 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
133195
} else if (pooltype == "SQRT") {
134196
out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}})) /
135197
std::sqrt(static_cast<T>(h));
136-
} else if (pooltype == "LAST") {
137-
out_e.device(place) = in_e.chip(h - 1, 0);
138-
} else if (pooltype == "FIRST") {
139-
out_e.device(place) = in_e.chip(0, 0);
140198
} else {
141199
PADDLE_THROW("unsupported pooling pooltype");
142200
}

0 commit comments

Comments
 (0)