Skip to content

Commit 72dd6b3

Browse files
author
chengduo
authored
Add sequence_expand_as_op (#13420)
* Add sequence_expand_as_op * follow comment
1 parent d5455b2 commit 72dd6b3

File tree

6 files changed

+594
-0
lines changed

6 files changed

+594
-0
lines changed

paddle/fluid/API.spec

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ paddle.fluid.layers.beam_search_decode ArgSpec(args=['ids', 'scores', 'beam_size
116116
paddle.fluid.layers.conv2d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None))
117117
paddle.fluid.layers.conv3d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None))
118118
paddle.fluid.layers.sequence_expand ArgSpec(args=['x', 'y', 'ref_level', 'name'], varargs=None, keywords=None, defaults=(-1, None))
119+
paddle.fluid.layers.sequence_expand_as ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,))
119120
paddle.fluid.layers.sequence_pad ArgSpec(args=['x', 'pad_value', 'maxlen'], varargs=None, keywords=None, defaults=(None,))
120121
paddle.fluid.layers.lstm_unit ArgSpec(args=['x_t', 'hidden_t_prev', 'cell_t_prev', 'forget_bias', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(0.0, None, None, None))
121122
paddle.fluid.layers.reduce_sum ArgSpec(args=['input', 'dim', 'keep_dim', 'name'], varargs=None, keywords=None, defaults=(None, False, None))
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/sequence_expand_as_op.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
20+
using framework::LoDTensor;
21+
22+
class SequenceExpandAsOp : public framework::OperatorWithKernel {
23+
public:
24+
using framework::OperatorWithKernel::OperatorWithKernel;
25+
26+
protected:
27+
void InferShape(framework::InferShapeContext* ctx) const override {
28+
PADDLE_ENFORCE(ctx->HasInput("X"),
29+
"Input(X) of SequenceExpandAsOp should not be null.");
30+
PADDLE_ENFORCE(ctx->HasInput("Y"),
31+
"Input(Y) of SequenceExpandAsOp should not be null.");
32+
PADDLE_ENFORCE(ctx->HasOutput("Out"),
33+
"Output(Out) of SequenceExpandAsOp should not be null.");
34+
35+
auto x_dims = ctx->GetInputDim("X");
36+
auto out_dims = x_dims;
37+
38+
PADDLE_ENFORCE_GE(x_dims.size(), 2,
39+
"Dimension number of Input(X) should be at least 2.");
40+
41+
if (ctx->IsRuntime()) {
42+
framework::Variable* x_var =
43+
boost::get<framework::Variable*>(ctx->GetInputVarPtrs("X")[0]);
44+
framework::Variable* y_var =
45+
boost::get<framework::Variable*>(ctx->GetInputVarPtrs("Y")[0]);
46+
47+
auto& x_dim = x_var->Get<LoDTensor>().dims();
48+
auto& y_lod = y_var->Get<LoDTensor>().lod();
49+
50+
PADDLE_ENFORCE_EQ(y_lod.size(), 1,
51+
"Level number of Input(Y)'s lod should be 1.");
52+
53+
PADDLE_ENFORCE_EQ(static_cast<size_t>(x_dim[0]), y_lod[0].size() - 1,
54+
"The first dimension of Input(X) should be equal "
55+
"to the size of Input(Y)'s 0 level lod.");
56+
57+
int64_t out_first_dim = 0;
58+
if (y_lod[0].size() <= 1) {
59+
out_first_dim = x_dims[0];
60+
} else {
61+
for (size_t i = 1; i < y_lod[0].size(); ++i) {
62+
out_first_dim += (y_lod[0][i] - y_lod[0][i - 1]);
63+
}
64+
}
65+
out_dims[0] = out_first_dim;
66+
} else {
67+
out_dims[0] = -1;
68+
}
69+
70+
ctx->SetOutputDim("Out", out_dims);
71+
ctx->ShareLoD("Y", /*->*/ "Out");
72+
}
73+
};
74+
75+
class SequenceExpandAsOpMaker : public framework::OpProtoAndCheckerMaker {
76+
public:
77+
void Make() override {
78+
AddInput("X",
79+
"(LoDTensor, default LoDTensor<float>) A 2-D LoDTensor whose lod "
80+
"level is at most 1.");
81+
AddInput("Y",
82+
"(LoDTensor, default LoDTensor<float>) Referred LoDTensor whose "
83+
"lod (specified level) is referred by Input(X).");
84+
AddOutput("Out",
85+
"(LodTensor, default LoDTensor<float>) Output LoDTensor which is "
86+
"generated from Input(X) by referring lod of Input(Y).");
87+
AddComment(R"DOC(
88+
Sequence Expand As Operator.
89+
90+
This operator expands `X` according to the zeroth level lod of `Y`. Current
91+
implementation requires the level number of Input(Y)'s lod should be 1, and
92+
the first dimension of Input(X) should be equal to the size of Input(Y)'s zeroth
93+
level lod, and lod of Input(X) is not considered.
94+
95+
Following are cases to better explain how this works:
96+
97+
Case 1:
98+
99+
Given a 1-level LoDTensor input(X)
100+
X.data = [[a], [b], [c], [d]]
101+
X.dims = [4, 1]
102+
and input(Y)
103+
Y.lod = [[0, 3, 6, 7, 8]]
104+
ref_level: 0
105+
then we get 1-level LoDTensor
106+
Out.lod = [[0, 3, 6, 7, 8]]
107+
Out.data = [[a], [a], [a], [b], [b], [b], [c], [d]]
108+
Out.dims = [8, 1]
109+
110+
Case 2:
111+
112+
Given a common Tensor input(X)
113+
X.data = [[a, b], [c, d], [e, f]]
114+
X.dims = [3, 2]
115+
and input(Y)
116+
Y.lod = [[0, 2, 3, 6]]
117+
ref_level: 0
118+
then we get a common LoDTensor
119+
Out.lod = [[0, 2, 3, 6]]
120+
Out.data = [[a, b], [a, b] [c, d], [e, f], [e, f], [e, f]]
121+
Out.dims = [6, 2]
122+
123+
)DOC");
124+
}
125+
};
126+
127+
class SequenceExpandAsOpGrad : public framework::OperatorWithKernel {
128+
public:
129+
using framework::OperatorWithKernel::OperatorWithKernel;
130+
131+
protected:
132+
void InferShape(framework::InferShapeContext* ctx) const override {
133+
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
134+
PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) should not be null.");
135+
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
136+
"Input(Out@GRAD) should not be null.");
137+
138+
auto x_dims = ctx->GetInputDim("X");
139+
auto x_grad_name = framework::GradVarName("X");
140+
141+
if (ctx->HasOutput(x_grad_name)) {
142+
ctx->SetOutputDim(x_grad_name, x_dims);
143+
ctx->ShareLoD("X", x_grad_name);
144+
}
145+
}
146+
};
147+
148+
} // namespace operators
149+
} // namespace paddle
150+
151+
namespace ops = paddle::operators;
152+
REGISTER_OPERATOR(sequence_expand_as, ops::SequenceExpandAsOp,
153+
ops::SequenceExpandAsOpMaker,
154+
paddle::framework::DefaultGradOpDescMaker<true>);
155+
REGISTER_OPERATOR(sequence_expand_as_grad, ops::SequenceExpandAsOpGrad);
156+
REGISTER_OP_CPU_KERNEL(
157+
sequence_expand_as,
158+
ops::SequenceExpandAsKernel<paddle::platform::CPUDeviceContext, float>,
159+
ops::SequenceExpandAsKernel<paddle::platform::CPUDeviceContext, double>,
160+
ops::SequenceExpandAsKernel<paddle::platform::CPUDeviceContext, int>,
161+
ops::SequenceExpandAsKernel<paddle::platform::CPUDeviceContext, int64_t>);
162+
REGISTER_OP_CPU_KERNEL(
163+
sequence_expand_as_grad,
164+
ops::SequenceExpandAsGradKernel<paddle::platform::CPUDeviceContext, float>,
165+
ops::SequenceExpandAsGradKernel<paddle::platform::CPUDeviceContext, double>,
166+
ops::SequenceExpandAsGradKernel<paddle::platform::CPUDeviceContext, int>,
167+
ops::SequenceExpandAsGradKernel<paddle::platform::CPUDeviceContext,
168+
int64_t>);
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <algorithm>
16+
#include "paddle/fluid/operators/sequence_expand_as_op.h"
17+
#include "paddle/fluid/platform/cuda_primitives.h"
18+
19+
namespace paddle {
20+
namespace operators {
21+
22+
using LoDTensor = framework::LoDTensor;
23+
24+
template <typename T>
25+
static __global__ void sequence_expand_as_kernel(const T *in_data,
26+
const size_t *expand_offset,
27+
const size_t src_hight,
28+
const size_t src_widht,
29+
T *out_data) {
30+
for (int h_id = blockIdx.x; h_id < src_hight; h_id += gridDim.x) {
31+
int span = expand_offset[h_id + 1] - expand_offset[h_id];
32+
if (span == 0) continue;
33+
const T *src = in_data + h_id * src_widht;
34+
for (int w_id = threadIdx.x; w_id < src_widht; w_id += blockDim.x) {
35+
T ele = src[w_id];
36+
int offset = expand_offset[h_id] * src_widht;
37+
for (int k = 0; k < span; ++k) {
38+
out_data[offset + k * src_widht + w_id] = ele;
39+
}
40+
}
41+
}
42+
}
43+
44+
template <typename T>
45+
static __global__ void sequence_expand_as_grad_kernel(
46+
const T *dout_data, const size_t *expand_offset, const size_t dst_hight,
47+
const size_t dst_width, T *dx_data) {
48+
for (int h_id = blockIdx.x; h_id < dst_hight; h_id += gridDim.x) {
49+
T *dst = dx_data + h_id * dst_width;
50+
int span = expand_offset[h_id + 1] - expand_offset[h_id];
51+
52+
for (int w_id = threadIdx.x; w_id < dst_width; w_id += blockDim.x) {
53+
T result = 0;
54+
for (int k = 0; k < span; ++k) {
55+
int offset = (expand_offset[h_id] + k) * dst_width;
56+
const T *src = dout_data + offset;
57+
result += src[w_id];
58+
}
59+
dst[w_id] = result;
60+
}
61+
}
62+
}
63+
64+
template <typename T>
65+
struct SequenceExpandFunctor<platform::CUDADeviceContext, T> {
66+
void operator()(
67+
const platform::CUDADeviceContext &context, const LoDTensor &x,
68+
const framework::Vector<size_t> &ref_lod, /*expand referenced lod*/
69+
LoDTensor *out) {
70+
int hight = x.dims()[0];
71+
int width = framework::product(x.dims()) / hight;
72+
73+
const int kThreadsPerBlock = 1024;
74+
int thread_x = kThreadsPerBlock;
75+
if (width < kThreadsPerBlock) { // block_cols is aligned by 32.
76+
thread_x = ((width + 31) >> 5) << 5;
77+
}
78+
79+
int max_threads = context.GetMaxPhysicalThreadCount();
80+
int block_x = std::max(max_threads / thread_x, 1);
81+
82+
dim3 block_size(thread_x);
83+
dim3 grid_size(block_x);
84+
sequence_expand_as_kernel<<<grid_size, block_size, 0, context.stream()>>>(
85+
x.data<T>(), ref_lod.CUDAData(context.GetPlace()), hight, width,
86+
out->mutable_data<T>(context.GetPlace()));
87+
}
88+
};
89+
90+
template <typename T>
91+
struct SequenceExpandAsGradFunctor<platform::CUDADeviceContext, T> {
92+
void operator()(const platform::CUDADeviceContext &context,
93+
const LoDTensor &dout,
94+
const framework::Vector<size_t> &ref_lod, /*expand based lod*/
95+
LoDTensor *dx) {
96+
int hight = dx->dims()[0];
97+
int width = framework::product(dx->dims()) / hight;
98+
99+
const int kThreadsPerBlock = 1024;
100+
int thread_x = kThreadsPerBlock;
101+
if (width < kThreadsPerBlock) { // block_cols is aligned by 32.
102+
thread_x = ((width + 31) >> 5) << 5;
103+
}
104+
105+
int max_threads = context.GetMaxPhysicalThreadCount();
106+
int block_x = std::max(max_threads / thread_x, 1);
107+
108+
dim3 block_size(thread_x);
109+
dim3 grid_size(block_x);
110+
sequence_expand_as_grad_kernel<<<grid_size, block_size, 0,
111+
context.stream()>>>(
112+
dout.data<T>(), ref_lod.CUDAData(context.GetPlace()), hight, width,
113+
dx->mutable_data<T>(context.GetPlace()));
114+
}
115+
};
116+
117+
} // namespace operators
118+
} // namespace paddle
119+
120+
namespace ops = paddle::operators;
121+
REGISTER_OP_CUDA_KERNEL(
122+
sequence_expand_as,
123+
ops::SequenceExpandAsKernel<paddle::platform::CUDADeviceContext, float>,
124+
ops::SequenceExpandAsKernel<paddle::platform::CUDADeviceContext, double>,
125+
ops::SequenceExpandAsKernel<paddle::platform::CUDADeviceContext, int>,
126+
ops::SequenceExpandAsKernel<paddle::platform::CUDADeviceContext, int64_t>);
127+
REGISTER_OP_CUDA_KERNEL(
128+
sequence_expand_as_grad,
129+
ops::SequenceExpandAsGradKernel<paddle::platform::CUDADeviceContext, float>,
130+
ops::SequenceExpandAsGradKernel<paddle::platform::CUDADeviceContext,
131+
double>,
132+
ops::SequenceExpandAsGradKernel<paddle::platform::CUDADeviceContext, int>,
133+
ops::SequenceExpandAsGradKernel<paddle::platform::CUDADeviceContext,
134+
int64_t>);

0 commit comments

Comments
 (0)