
Commit 561d634

Merge pull request #4061 from pkuyym/fix-4029

Add expand operator

2 parents: 58b4c9a + d7e7a1d

4 files changed: 428 additions, 0 deletions

paddle/operators/expand_op.cc

Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/expand_op.h"

namespace paddle {
namespace operators {

using framework::Tensor;

class ExpandOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");

    std::vector<int> expand_times =
        ctx->Attrs().Get<std::vector<int>>("expand_times");
    auto x_dims = ctx->GetInputDim("X");

    PADDLE_ENFORCE_EQ(static_cast<size_t>(x_dims.size()), expand_times.size(),
                      "The number of Attr(expand_times)'s values must be "
                      "equal to the rank of Input(X).");
    PADDLE_ENFORCE_LE(x_dims.size(), 6,
                      "The rank of Input(X) must not be greater than 6.");

    std::vector<int64_t> out_shape(x_dims.size());
    for (size_t i = 0; i < expand_times.size(); ++i) {
      PADDLE_ENFORCE_GE(expand_times[i], 1,
                        "Each value of Attr(expand_times) should not be "
                        "less than 1.");
      out_shape[i] = x_dims[i] * expand_times[i];
    }

    ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
    if (out_shape[0] == x_dims[0]) {
      ctx->ShareLoD("X", "Out");
    }
  }
};

class ExpandOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ExpandOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X",
             "(Tensor, default Tensor<float>) A tensor with rank in [1, 6]. "
             "X is the input tensor to be expanded.");
    AddOutput("Out",
              "(Tensor, default Tensor<float>) A tensor with rank in [1, 6]. "
              "The rank of Output(Out) is the same as that of Input(X), "
              "except that each dimension size of Output(Out) equals the "
              "corresponding dimension size of Input(X) multiplied by the "
              "corresponding value of Attr(expand_times).");
    AddAttr<std::vector<int>>("expand_times",
                              "The number of times to tile each dimension.");
    AddComment(R"DOC(
Expand operator tiles the input by the given number of times. You set the
number of times for each dimension by providing the attribute 'expand_times'.
The rank of X should be in [1, 6]. Note that the size of 'expand_times' must
be the same as the rank of X. The following is a usage example:

Input(X) is a 3-D tensor with shape [2, 3, 1]:

        [
           [[1], [2], [3]],
           [[4], [5], [6]]
        ]

Attr(expand_times):  [1, 2, 2]

Output(Out) is a 3-D tensor with shape [2, 6, 2]:

        [
            [[1, 1], [2, 2], [3, 3], [1, 1], [2, 2], [3, 3]],
            [[4, 4], [5, 5], [6, 6], [4, 4], [5, 5], [6, 6]]
        ]

)DOC");
  }
};

class ExpandGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "Input(Out@GRAD) should not be null.");

    auto x_dims = ctx->GetInputDim("X");
    std::vector<int> expand_times =
        ctx->Attrs().Get<std::vector<int>>("expand_times");
    auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));

    for (size_t i = 0; i < expand_times.size(); ++i) {
      PADDLE_ENFORCE_EQ(x_dims[i] * expand_times[i], out_dims[i],
                        "Each dimension size of Input(Out@GRAD) should be "
                        "equal to the product of the corresponding dimension "
                        "size of Input(X) and the corresponding "
                        "Attr(expand_times) value.");
    }

    auto x_grad_name = framework::GradVarName("X");

    if (ctx->HasOutput(x_grad_name)) {
      ctx->SetOutputDim(x_grad_name, x_dims);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(expand, ops::ExpandOp, ops::ExpandOpMaker, expand_grad,
            ops::ExpandGradOp);
REGISTER_OP_CPU_KERNEL(expand,
                       ops::ExpandKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
    expand_grad, ops::ExpandGradKernel<paddle::platform::CPUPlace, float>);
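
Note: the DOC example above can be checked against the shape rule that
InferShape enforces, out_shape[i] = x_dims[i] * expand_times[i]. The following
is a minimal, framework-free C++ sketch of that arithmetic; x_dims and
expand_times are local stand-ins, not Paddle API calls:

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Mirrors ExpandOp::InferShape: ranks match, rank <= 6, each time >= 1.
  std::vector<int64_t> x_dims = {2, 3, 1};    // shape of Input(X)
  std::vector<int> expand_times = {1, 2, 2};  // Attr(expand_times)
  assert(x_dims.size() == expand_times.size() && x_dims.size() <= 6);

  std::vector<int64_t> out_shape(x_dims.size());
  for (size_t i = 0; i < x_dims.size(); ++i) {
    assert(expand_times[i] >= 1);
    out_shape[i] = x_dims[i] * expand_times[i];
  }
  for (int64_t d : out_shape) std::cout << d << ' ';  // prints: 2 6 2
  std::cout << '\n';
}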

paddle/operators/expand_op.cu

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU

#include "paddle/operators/expand_op.h"

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(expand,
                       ops::ExpandKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
    expand_grad, ops::ExpandGradKernel<paddle::platform::GPUPlace, float>);

paddle/operators/expand_op.h

Lines changed: 172 additions & 0 deletions
@@ -0,0 +1,172 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <boost/preprocessor/arithmetic/div.hpp>
#include <boost/preprocessor/arithmetic/mod.hpp>
#include <boost/preprocessor/comparison/greater.hpp>
#include <boost/preprocessor/comparison/greater_equal.hpp>
#include <boost/preprocessor/control/if.hpp>
#include <boost/preprocessor/repetition/repeat.hpp>
#include <iostream>
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"

#define MAX_RANK_SUPPORTED 6

#define EXPAND_TEMPLATE(z, n, data) \
  case n + 1: {                     \
    Expand<n + 1>(context);         \
    break;                          \
  }
#define REP_EXPAND_TEMPLATE(n) BOOST_PP_REPEAT(n, EXPAND_TEMPLATE, ~)
#define COND(n)                                               \
  BOOST_PP_GREATER_EQUAL(BOOST_PP_DIV(n, MAX_RANK_SUPPORTED), \
                         BOOST_PP_MOD(n, MAX_RANK_SUPPORTED))
#define EXPAND_GRAD_CASE(n)                                        \
  case n: {                                                        \
    ExpandBackward<n>(context, reshape_dims_vec, reduce_dims_vec); \
    break;                                                         \
  }
#define EXPAND_GRAD_TEMPLATE(z, n, data) \
  BOOST_PP_IF(COND(n), EXPAND_GRAD_CASE(n), )
#define REP_EXPAND_GRAD_TEMPLATE(n) BOOST_PP_REPEAT(n, EXPAND_GRAD_TEMPLATE, ~)
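// What the repetition macros above generate (a sketch, not the literal
// preprocessor output): REP_EXPAND_TEMPLATE(MAX_RANK_SUPPORTED) stamps out
// one switch case per supported rank in ExpandKernel::Compute below:
//   case 1: { Expand<1>(context); break; }
//   case 2: { Expand<2>(context); break; }
//   ...
//   case 6: { Expand<6>(context); break; }
// REP_EXPAND_GRAD_TEMPLATE(72) generates candidate cases 0..71 for the
// gradient switch, and COND(n) keeps only the values where
// n / MAX_RANK_SUPPORTED >= n % MAX_RANK_SUPPORTED, i.e. the
// (reshape_size, reduce_size) encodings that can actually occur (see the
// worked example after this file).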
namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;

template <typename Place, typename T>
class ExpandKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto rank = context.Input<Tensor>("X")->dims().size();
    switch (rank) {
      REP_EXPAND_TEMPLATE(MAX_RANK_SUPPORTED)
      default:
        PADDLE_ENFORCE(false,
                       "Only support tensor with rank being between 1 and 6.");
    }
  }

 protected:
  template <int Rank>
  void Expand(const framework::ExecutionContext& context) const {
    auto* in0 = context.Input<Tensor>("X");
    auto& expand_times = context.Attr<std::vector<int>>("expand_times");
    auto* out0 = context.Output<Tensor>("Out");
    Eigen::DSizes<int, Rank> bcast_dims;
    auto x_dims = in0->dims();
    for (size_t i = 0; i < expand_times.size(); ++i) {
      bcast_dims[i] = expand_times[i];
    }
    auto x = EigenTensor<T, Rank>::From(*in0);
    out0->mutable_data<T>(context.GetPlace());
    auto y = EigenTensor<T, Rank>::From(*out0);
    auto place = context.GetEigenDevice<Place>();
    y.device(place) = x.broadcast(bcast_dims);
  }
};

template <typename Place, typename T>
class ExpandGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in0 = context.Input<Tensor>("X");
    auto& expand_times = context.Attr<std::vector<int>>("expand_times");
    auto x_dims = in0->dims();
    // 1. reshape_dims_vec is the broadcast parameter. For each dimension i,
    //    if expand_times[i] > 1 and x_dims[i] > 1, dimension i will be split
    //    into two dimensions [expand_times[i], x_dims[i]].
    // 2. reduce_dims_vec is the dimension parameter for computing gradients.
    //    For each expanded dimension, the gradients should be summed back to
    //    the original size.
    std::vector<int> reshape_dims_vec;
    std::vector<int> reduce_dims_vec;
    for (size_t i = 0; i < expand_times.size(); ++i) {
      if (expand_times[i] == 1) {
        reshape_dims_vec.push_back(x_dims[i]);
      } else {
        if (x_dims[i] == 1) {
          reduce_dims_vec.push_back(reshape_dims_vec.size());
          reshape_dims_vec.push_back(expand_times[i]);
        } else {
          reduce_dims_vec.push_back(reshape_dims_vec.size());
          reshape_dims_vec.push_back(expand_times[i]);
          reshape_dims_vec.push_back(x_dims[i]);
        }
      }
    }

    // Encode (reshape_size, reduce_size) into a single integer so that one
    // switch can dispatch on both; equivalently,
    // dims = (reshape_size - 1) * MAX_RANK_SUPPORTED + (reduce_size - 1),
    // decoded in ExpandBackward.
    int dims = reshape_dims_vec.size() * MAX_RANK_SUPPORTED +
               reduce_dims_vec.size() - MAX_RANK_SUPPORTED - 1;
    // No reduction is needed; the gradient is just a copy.
    if (reduce_dims_vec.size() == 0) {
      auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
      auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
      out0->mutable_data<T>(context.GetPlace());
      out0->CopyFrom(*in0, context.GetPlace(), context.device_context());
    } else {
      switch (dims) {
        REP_EXPAND_GRAD_TEMPLATE(72)
        default:
          PADDLE_ENFORCE(
              false, "Only support tensor with rank being between 1 and 6.");
      }
    }
  }

 protected:
  template <int Dims>
  void ExpandBackward(const framework::ExecutionContext& context,
                      const std::vector<int>& reshape_dims_vec,
                      const std::vector<int>& reduce_dims_vec) const {
    size_t reshape_size = Dims / MAX_RANK_SUPPORTED + 1;
    size_t reduce_size = Dims % MAX_RANK_SUPPORTED + 1;
    PADDLE_ENFORCE_EQ(reshape_size, reshape_dims_vec.size(),
                      "Inconsistent size between template Dims and "
                      "reshape dimensions.");
    PADDLE_ENFORCE_EQ(reduce_size, reduce_dims_vec.size(),
                      "Inconsistent size between template Dims and "
                      "reduce dimensions.");
    auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
    auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
    auto x = EigenVector<T>::Flatten(*(context.Input<Tensor>("X")));
    out0->mutable_data<T>(context.GetPlace());
    auto x_grad = EigenVector<T>::Flatten(*out0);
    Eigen::DSizes<int, Dims / MAX_RANK_SUPPORTED + 1> reshape_dims;
    for (size_t i = 0; i < reshape_size; ++i) {
      reshape_dims[i] = reshape_dims_vec[i];
    }
    Eigen::DSizes<int, Dims % MAX_RANK_SUPPORTED + 1> reduce_dims;
    for (size_t i = 0; i < reduce_size; ++i) {
      reduce_dims[i] = reduce_dims_vec[i];
    }
    auto out_grad = EigenVector<T>::Flatten(*in0);
    x_grad.device(context.GetEigenDevice<Place>()) =
        out_grad.reshape(reshape_dims).sum(reduce_dims).reshape(x.dimensions());
  }
};

}  // namespace operators
}  // namespace paddle
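
Note: to make the gradient bookkeeping in ExpandGradKernel concrete, the
following is a minimal, framework-free sketch (kMaxRank is a local stand-in
for MAX_RANK_SUPPORTED) that runs the same splitting rule on the DOC example
from expand_op.cc:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int kMaxRank = 6;                     // stand-in for MAX_RANK_SUPPORTED
  std::vector<int64_t> x_dims = {2, 3, 1};    // Input(X) shape
  std::vector<int> expand_times = {1, 2, 2};  // Attr(expand_times)

  // Same splitting rule as ExpandGradKernel::Compute.
  std::vector<int> reshape_dims_vec;
  std::vector<int> reduce_dims_vec;
  for (size_t i = 0; i < expand_times.size(); ++i) {
    if (expand_times[i] == 1) {
      reshape_dims_vec.push_back(x_dims[i]);
    } else if (x_dims[i] == 1) {
      reduce_dims_vec.push_back(reshape_dims_vec.size());
      reshape_dims_vec.push_back(expand_times[i]);
    } else {
      reduce_dims_vec.push_back(reshape_dims_vec.size());
      reshape_dims_vec.push_back(expand_times[i]);
      reshape_dims_vec.push_back(x_dims[i]);
    }
  }
  // reshape_dims_vec = [2, 2, 3, 2]: Out@GRAD of shape [2, 6, 2] is viewed
  // as [2, 2, 3, 2]. reduce_dims_vec = [1, 3]: summing those axes yields
  // [2, 3], which reshapes back to x_dims = [2, 3, 1].

  // Pack (reshape_size, reduce_size) into the single switch value `dims`.
  int dims = static_cast<int>(reshape_dims_vec.size()) * kMaxRank +
             static_cast<int>(reduce_dims_vec.size()) - kMaxRank - 1;
  std::cout << "dims = " << dims << "\n";  // prints: dims = 19
  // ExpandBackward<19> decodes it back: 19 / 6 + 1 = 4 reshape dims and
  // 19 % 6 + 1 = 2 reduce dims, matching the vectors built above.
}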
