Skip to content

Commit 3e1050a

Browse files
author
chengduo
authored
Add pad_constant_like_op (#12943)
* Add pad_constant_batch_size_like * refine pad_op * optimize memory
1 parent d361624 commit 3e1050a

File tree

6 files changed

+529
-93
lines changed

6 files changed

+529
-93
lines changed

paddle/fluid/operators/math/padding.h

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include <utility>
17+
#include <vector>
18+
#include "paddle/fluid/framework/tensor.h"
19+
20+
namespace paddle {
21+
namespace operators {
22+
namespace math {
23+
24+
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
25+
typename IndexType = Eigen::DenseIndex>
26+
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
27+
28+
template <typename DeviceContext, typename T, size_t D>
29+
void PadFunction(const framework::ExecutionContext& context,
30+
const std::vector<int>& pads, const framework::Tensor& src,
31+
T pad_value, framework::Tensor* out) {
32+
Eigen::array<std::pair<int, int>, D> paddings;
33+
34+
for (size_t i = 0; i < paddings.size(); ++i) {
35+
paddings[i].first = pads[i * 2];
36+
paddings[i].second = pads[i * 2 + 1];
37+
}
38+
39+
auto src_tensor = EigenTensor<T, D>::From(src);
40+
auto out_tensor = EigenTensor<T, D>::From(*out);
41+
42+
auto& place =
43+
*context.template device_context<DeviceContext>().eigen_device();
44+
out_tensor.device(place) = src_tensor.pad(paddings, pad_value);
45+
}
46+
47+
template <typename DeviceContext, typename T, size_t D>
48+
void PadGradFunction(const framework::ExecutionContext& context,
49+
const std::vector<int>& pads, const framework::Tensor& src,
50+
framework::Tensor* d_out) {
51+
Eigen::array<std::pair<int, int>, D> paddings;
52+
for (size_t i = 0; i < paddings.size(); ++i) {
53+
paddings[i].first = -pads[i * 2];
54+
paddings[i].second = -pads[i * 2 + 1];
55+
}
56+
57+
auto d_out_tensor = EigenTensor<T, D>::From(*d_out);
58+
auto src_tensor = EigenTensor<T, D>::From(src);
59+
auto& place =
60+
*context.template device_context<DeviceContext>().eigen_device();
61+
d_out_tensor.device(place) = src_tensor.pad(paddings, 0);
62+
}
63+
64+
template <typename DeviceContext, typename T>
65+
void PaddingFunctor(int rank, const framework::ExecutionContext& context,
66+
const std::vector<int>& pads, T pad_value,
67+
const framework::Tensor& src, framework::Tensor* out) {
68+
switch (rank) {
69+
case 1:
70+
PadFunction<DeviceContext, T, 1>(context, pads, src, pad_value, out);
71+
break;
72+
case 2:
73+
PadFunction<DeviceContext, T, 2>(context, pads, src, pad_value, out);
74+
break;
75+
case 3:
76+
PadFunction<DeviceContext, T, 3>(context, pads, src, pad_value, out);
77+
break;
78+
case 4:
79+
PadFunction<DeviceContext, T, 4>(context, pads, src, pad_value, out);
80+
break;
81+
case 5:
82+
PadFunction<DeviceContext, T, 5>(context, pads, src, pad_value, out);
83+
break;
84+
case 6:
85+
PadFunction<DeviceContext, T, 6>(context, pads, src, pad_value, out);
86+
break;
87+
default:
88+
PADDLE_THROW(
89+
"PadOp only support tensors with no more than 6 dimensions.");
90+
}
91+
}
92+
93+
template <typename DeviceContext, typename T>
94+
void PaddingGradFunctor(int rank, const framework::ExecutionContext& context,
95+
const std::vector<int>& pads,
96+
const framework::Tensor& src, framework::Tensor* out) {
97+
switch (rank) {
98+
case 1:
99+
PadGradFunction<DeviceContext, T, 1>(context, pads, src, out);
100+
break;
101+
case 2:
102+
PadGradFunction<DeviceContext, T, 2>(context, pads, src, out);
103+
break;
104+
case 3:
105+
PadGradFunction<DeviceContext, T, 3>(context, pads, src, out);
106+
break;
107+
case 4:
108+
PadGradFunction<DeviceContext, T, 4>(context, pads, src, out);
109+
break;
110+
case 5:
111+
PadGradFunction<DeviceContext, T, 5>(context, pads, src, out);
112+
break;
113+
case 6:
114+
PadGradFunction<DeviceContext, T, 6>(context, pads, src, out);
115+
break;
116+
default:
117+
PADDLE_THROW(
118+
"PadOp only support tensors with no more than 6 dimensions.");
119+
}
120+
}
121+
122+
} // namespace math
123+
} // namespace operators
124+
} // namespace paddle
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/pad_constant_like_op.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
20+
using framework::Tensor;
21+
22+
class PadConstantLikeOp : public framework::OperatorWithKernel {
23+
public:
24+
using framework::OperatorWithKernel::OperatorWithKernel;
25+
26+
void InferShape(framework::InferShapeContext *ctx) const override {
27+
PADDLE_ENFORCE(ctx->HasInput("X"),
28+
"Input(X) of PadConstantLikeOp should not be null.");
29+
PADDLE_ENFORCE(ctx->HasInput("Y"),
30+
"Input(Y) of PadConstantLikeOp should not be null.");
31+
PADDLE_ENFORCE(ctx->HasOutput("Out"),
32+
"Output(Out) of PadConstantLikeOp should not be null.");
33+
34+
auto x_dim = ctx->GetInputDim("X");
35+
auto y_dim = ctx->GetInputDim("Y");
36+
37+
PADDLE_ENFORCE_EQ(x_dim.size(), y_dim.size(),
38+
"The dimention of X and Y should be the same.");
39+
40+
for (int i = 0; i < x_dim.size(); ++i) {
41+
PADDLE_ENFORCE_GE(x_dim[i], y_dim[i]);
42+
}
43+
ctx->SetOutputDim("Out", x_dim);
44+
ctx->ShareLoD("X", /*->*/ "Out");
45+
}
46+
};
47+
48+
class PadConstantLikeOpMaker : public framework::OpProtoAndCheckerMaker {
49+
public:
50+
void Make() override {
51+
AddInput("X",
52+
"The input of pad_constant_like op. "
53+
"The input should be a k-D tensor(k > 0 and k < 7)");
54+
AddInput("Y",
55+
"The input of pad_constant_like op. "
56+
"The input should be a k-D tensor(k > 0 and k < 7)");
57+
AddOutput("Out",
58+
"The output of pad_constant_like op. "
59+
"A tensor with the same shape as X.");
60+
AddAttr<float>("pad_value",
61+
"(float, default 0.0) "
62+
"The value to fill the padded areas.")
63+
.SetDefault(0.0f);
64+
AddComment(R"DOC(
65+
PadConstantLikeOp Operator.
66+
67+
Pad input(Y) with a pad_value, the number of values padded to the edges of each
68+
axis is specified by the difference of the shape of X and Y.
69+
((0, shape_x_0 - shape_y_0), … (0, shape_x_n - shape_y_n)) unique pad widths for
70+
each axis.
71+
The input should be a k-D tensor(k > 0 and k < 7). As an example:
72+
73+
case1:
74+
Given:
75+
X = [[1, 2],
76+
[3, 4],
77+
[1, 2],
78+
[3, 4]]],
79+
X.shape = (4, 2)
80+
81+
Y = [[5, 6],
82+
[7, 8]],
83+
Y.shape = (2, 2)
84+
85+
And
86+
pad_value = 0,
87+
88+
Return:
89+
Out = [[5, 6],
90+
[7, 8],
91+
[0, 0],
92+
[0, 0]]
93+
Out.shape = (4, 2)
94+
95+
case2:
96+
Given:
97+
X = [[[[ 0, 1, 2],
98+
[ 3, 4, 5]],
99+
[[ 6, 7, 8],
100+
[ 9, 10, 11]],
101+
[[12, 13, 14],
102+
[15, 16, 17]]],
103+
[[[18, 19, 20],
104+
[21, 22, 23]],
105+
[[24, 25, 26],
106+
[27, 28, 29]],
107+
[[30, 31, 32],
108+
[33, 34, 35]]]]
109+
X.shape = (2, 3, 2, 3)
110+
111+
Y = [[[[35, 36, 37]],
112+
[[38, 39, 40]],
113+
[[41, 42, 43]]]]
114+
Y.shape = (1, 3, 1, 3)
115+
116+
And
117+
pad_value = -1,
118+
119+
Return:
120+
121+
Out = [[[[35, 36, 37],
122+
[-1, -1, -1]],
123+
[[38, 39, 40],
124+
[-1, -1, -1]],
125+
[[41, 42, 43],
126+
[-1, -1, -1]]],
127+
[[[-1, -1, -1],
128+
[-1, -1, -1]],
129+
[[-1, -1, -1],
130+
[-1, -1, -1]],
131+
[[-1, -1, -1],
132+
[-1, -1, -1]]]]
133+
Out.shape = (2, 3, 2, 3)
134+
)DOC");
135+
}
136+
};
137+
138+
class PadConstantLikeOpGrad : public framework::OperatorWithKernel {
139+
public:
140+
using framework::OperatorWithKernel::OperatorWithKernel;
141+
142+
void InferShape(framework::InferShapeContext *ctx) const override {
143+
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
144+
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
145+
"Input(Out@GRAD) should not be null");
146+
auto y_dim = ctx->GetInputDim("Y");
147+
auto dout_dim = ctx->GetInputDim(framework::GradVarName("Out"));
148+
149+
PADDLE_ENFORCE_EQ(dout_dim.size(), y_dim.size(),
150+
"The dimention of X and Y should be the same.");
151+
152+
auto y_grad_name = framework::GradVarName("Y");
153+
if (ctx->HasOutput(y_grad_name)) {
154+
ctx->SetOutputDim(y_grad_name, y_dim);
155+
ctx->ShareLoD("Y", /*->*/ y_grad_name);
156+
157+
for (int i = 0; i < y_dim.size(); ++i) {
158+
PADDLE_ENFORCE_GE(dout_dim[i], y_dim[i]);
159+
}
160+
}
161+
}
162+
};
163+
164+
class PadConstantLikeOpGradMaker : public framework::SingleGradOpDescMaker {
165+
public:
166+
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
167+
168+
protected:
169+
std::unique_ptr<framework::OpDesc> Apply() const override {
170+
auto *bind = new framework::OpDesc();
171+
bind->SetType("pad_constant_like_grad");
172+
bind->SetInput("Y", Input("Y"));
173+
bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
174+
bind->SetOutput(framework::GradVarName("Y"), InputGrad("Y"));
175+
bind->SetAttrMap(Attrs());
176+
return std::unique_ptr<framework::OpDesc>(bind);
177+
}
178+
};
179+
180+
} // namespace operators
181+
} // namespace paddle
182+
183+
namespace ops = paddle::operators;
184+
185+
REGISTER_OPERATOR(pad_constant_like, ops::PadConstantLikeOp,
186+
ops::PadConstantLikeOpMaker, ops::PadConstantLikeOpGradMaker);
187+
REGISTER_OPERATOR(pad_constant_like_grad, ops::PadConstantLikeOpGrad);
188+
189+
REGISTER_OP_CPU_KERNEL(
190+
pad_constant_like,
191+
ops::PadConstantLikeKernel<paddle::platform::CPUDeviceContext, float>,
192+
ops::PadConstantLikeKernel<paddle::platform::CPUDeviceContext, double>);
193+
REGISTER_OP_CPU_KERNEL(
194+
pad_constant_like_grad,
195+
ops::PadConstantLikeGradKernel<paddle::platform::CPUDeviceContext, float>,
196+
ops::PadConstantLikeGradKernel<paddle::platform::CPUDeviceContext, double>);
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#define EIGEN_USE_GPU
16+
#include "paddle/fluid/operators/pad_constant_like_op.h"
17+
18+
namespace ops = paddle::operators;
19+
REGISTER_OP_CUDA_KERNEL(
20+
pad_constant_like,
21+
ops::PadConstantLikeKernel<paddle::platform::CUDADeviceContext, float>,
22+
ops::PadConstantLikeKernel<paddle::platform::CUDADeviceContext, double>);
23+
REGISTER_OP_CUDA_KERNEL(
24+
pad_constant_like_grad,
25+
ops::PadConstantLikeGradKernel<paddle::platform::CUDADeviceContext, float>,
26+
ops::PadConstantLikeGradKernel<paddle::platform::CUDADeviceContext,
27+
double>);

0 commit comments

Comments
 (0)