Commit d04c853

Refine .cc and .h; make unit tests more readable.

1 parent 0d9ba3d commit d04c853

File tree: 3 files changed, +46 -32 lines

paddle/operators/expand_op.cc

Lines changed: 16 additions & 11 deletions
@@ -25,27 +25,31 @@ class ExpandOp : public framework::OperatorWithKernel {
 
  protected:
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must be initialized.");
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
+
     std::vector<int> expand_times =
-        ctx->Attrs().Get<std::vector<int>>("expandTimes");
+        ctx->Attrs().Get<std::vector<int>>("expand_times");
     auto x_dims = ctx->GetInputDim("X");
 
     PADDLE_ENFORCE_EQ(static_cast<size_t>(x_dims.size()), expand_times.size(),
-                      "The number of Attr(expandTimes)'s value must be equal "
+                      "The number of Attr(expand_times)'s value must be equal "
                       "to the rank of Input(X).");
     PADDLE_ENFORCE_LE(x_dims.size(), 6,
                       "The rank of Input(X) must not be greater than 6.");
 
     std::vector<int64_t> out_shape(x_dims.size());
     for (size_t i = 0; i < expand_times.size(); ++i) {
       PADDLE_ENFORCE_GE(expand_times[i], 1,
-                        "Each value of Attr(expandTimes) should not be "
+                        "Each value of Attr(expand_times) should not be "
                         "less than 1.");
       out_shape[i] = x_dims[i] * expand_times[i];
     }
 
     ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
-    ctx->ShareLoD("X", "Out");
+    if (out_shape[0] == x_dims[0]) {
+      ctx->ShareLoD("X", "Out");
+    }
   }
 };
 
@@ -61,13 +65,13 @@ class ExpandOpMaker : public framework::OpProtoAndCheckerMaker {
              "The rank of Output(Out) is same as Input(X) except that each "
              "dimension size of Output(Out) is equal to corresponding "
              "dimension size of Input(X) multiplying corresponding value of "
-             "Attr(expandTimes).");
-    AddAttr<std::vector<int>>("expandTimes",
+             "Attr(expand_times).");
+    AddAttr<std::vector<int>>("expand_times",
                               "Expand times number for each dimension.");
     AddComment(R"DOC(
 Expand operator tiles the input by given times number. You should set times
-number for each dimension by providing attribute 'expandTimes'. The rank of X
-should be in [1, 6]. Please notice that size of 'expandTimes' must be same with
+number for each dimension by providing attribute 'expand_times'. The rank of X
+should be in [1, 6]. Please notice that size of 'expand_times' must be same with
 X's rank.
 )DOC");
   }
@@ -82,16 +86,17 @@ class ExpandGradOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                    "Input(Out@GRAD) should not be null.");
+
     auto x_dims = ctx->GetInputDim("X");
     std::vector<int> expand_times =
-        ctx->Attrs().Get<std::vector<int>>("expandTimes");
+        ctx->Attrs().Get<std::vector<int>>("expand_times");
     auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
 
     for (size_t i = 0; i < expand_times.size(); ++i) {
       PADDLE_ENFORCE_EQ(x_dims[i] * expand_times[i], out_dims[i],
                         "Each dimension size of Input(Out@GRAD) should be "
                         "equal to multiplication of corresponding dimension "
-                        "size of Input(X) and Attr(expandTimes) value.");
+                        "size of Input(X) and Attr(expand_times) value.");
     }
 
     auto x_grad_name = framework::GradVarName("X");
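For reference, the forward behavior that InferShape sizes above matches
numpy's tile, which the unit tests below use as their reference
implementation. A minimal sketch, with illustrative shapes and variable
names (nothing here is part of the operator's API):

    import numpy as np

    # Each dimension i of the input is tiled expand_times[i] times, so the
    # output shape is x_dims[i] * expand_times[i] -- the same product that
    # InferShape writes into out_shape.
    x = np.random.random((12, 14)).astype("float32")
    expand_times = [2, 3]

    out = np.tile(x, expand_times)
    assert out.shape == (12 * 2, 14 * 3)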

paddle/operators/expand_op.h

Lines changed: 20 additions & 11 deletions
@@ -25,14 +25,17 @@
 #include "paddle/framework/op_registry.h"
 #include "paddle/framework/operator.h"
 
+#define MAX_RANK_SUPPORTED 6
+
 #define EXPAND_TEMPLATE(z, n, data) \
   case n + 1: {                     \
     Expand<n + 1>(context);         \
     break;                          \
   }
 #define REP_EXPAND_TEMPLATE(n) BOOST_PP_REPEAT(n, EXPAND_TEMPLATE, ~)
-
-#define COND(n) BOOST_PP_GREATER_EQUAL(BOOST_PP_DIV(n, 6), BOOST_PP_MOD(n, 6))
+#define COND(n)                                               \
+  BOOST_PP_GREATER_EQUAL(BOOST_PP_DIV(n, MAX_RANK_SUPPORTED), \
+                         BOOST_PP_MOD(n, MAX_RANK_SUPPORTED))
 #define EXPAND_GRAD_CASE(n)                                        \
   case n: {                                                        \
     ExpandBackward<n>(context, reshape_dims_vec, reduce_dims_vec); \
@@ -46,7 +49,6 @@ namespace paddle {
 namespace operators {
 
 using Tensor = framework::Tensor;
-
 template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
@@ -60,7 +62,7 @@ class ExpandKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& context) const override {
     auto rank = context.Input<Tensor>("X")->dims().size();
     switch (rank) {
-      REP_EXPAND_TEMPLATE(6)
+      REP_EXPAND_TEMPLATE(MAX_RANK_SUPPORTED)
       default:
         PADDLE_ENFORCE(false,
                        "Only support tensor with rank being between 1 and 6.");
@@ -71,7 +73,7 @@ class ExpandKernel : public framework::OpKernel<T> {
   template <int Rank>
   void Expand(const framework::ExecutionContext& context) const {
     auto* in0 = context.Input<Tensor>("X");
-    auto& expand_times = context.Attr<std::vector<int>>("expandTimes");
+    auto& expand_times = context.Attr<std::vector<int>>("expand_times");
     auto* out0 = context.Output<Tensor>("Out");
     Eigen::DSizes<int, Rank> bcast_dims;
     auto x_dims = in0->dims();
@@ -91,8 +93,14 @@ class ExpandGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto* in0 = context.Input<Tensor>("X");
-    auto& expand_times = context.Attr<std::vector<int>>("expandTimes");
+    auto& expand_times = context.Attr<std::vector<int>>("expand_times");
     auto x_dims = in0->dims();
+    // 1. reshape_dims_vec is the broadcast parameter. For each dimension i,
+    // if expand_times[i] > 1 and x_dims[i] > 1, i will be split into two
+    // dimensions [expand_times[i], x_dims[i]].
+    // 2. reduce_dims_vec is the dimension parameter to compute gradients. For
+    // each dimension expanded, the gradients should be summed to the original
+    // size.
     std::vector<int> reshape_dims_vec;
     std::vector<int> reduce_dims_vec;
     for (size_t i = 0; i < expand_times.size(); ++i) {
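The two numbered comments added in this hunk carry the whole backward
strategy, so a small sketch may help. Here it is in numpy, with made-up
shapes (x_dims = [2, 3], expand_times = [2, 1]) used purely for
illustration:

    import numpy as np

    x_dims, expand_times = [2, 3], [2, 1]
    # Out@GRAD has shape [t * d] per dimension, here (4, 3).
    out_grad = np.ones([t * d for t, d in zip(expand_times, x_dims)])

    # Dimension 0 was expanded, so it splits into [expand_times[0], x_dims[0]];
    # dimension 1 was untouched and stays a single axis.
    reshape_dims_vec = [2, 2, 3]
    reduce_dims_vec = (0,)  # sum over the tiled copies of dimension 0

    x_grad = out_grad.reshape(reshape_dims_vec).sum(axis=reduce_dims_vec)
    assert x_grad.shape == tuple(x_dims)  # summed back to the original (2, 3)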
@@ -110,7 +118,8 @@ class ExpandGradKernel : public framework::OpKernel<T> {
       }
     }
 
-    int dims = reshape_dims_vec.size() * 6 + reduce_dims_vec.size() - 7;
+    int dims = reshape_dims_vec.size() * MAX_RANK_SUPPORTED +
+               reduce_dims_vec.size() - MAX_RANK_SUPPORTED - 1;
     // no need reduce, just copy
     if (reduce_dims_vec.size() == 0) {
       auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
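The dims value computed above packs both vector sizes into one integer so a
single template parameter can drive the dispatch. A quick round-trip check
of the encoding (Python used only to verify the arithmetic; encode/decode
are hypothetical helpers, not functions in this codebase):

    MAX_RANK_SUPPORTED = 6

    def encode(reshape_size, reduce_size):
        # Same arithmetic as the C++ above: (reshape_size - 1) * 6 + (reduce_size - 1).
        return (reshape_size * MAX_RANK_SUPPORTED + reduce_size
                - MAX_RANK_SUPPORTED - 1)

    def decode(dims):
        # Mirrors ExpandBackward below: Dims / 6 + 1 and Dims % 6 + 1.
        return dims // MAX_RANK_SUPPORTED + 1, dims % MAX_RANK_SUPPORTED + 1

    assert decode(encode(3, 1)) == (3, 1)
    # COND(n) in the macro section admits exactly the n with n / 6 >= n % 6,
    # i.e. reshape_size >= reduce_size, so only valid pairs are instantiated.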
@@ -132,8 +141,8 @@ class ExpandGradKernel : public framework::OpKernel<T> {
   void ExpandBackward(const framework::ExecutionContext& context,
                       const std::vector<int>& reshape_dims_vec,
                       const std::vector<int>& reduce_dims_vec) const {
-    size_t reshape_size = Dims / 6 + 1;
-    size_t reduce_size = Dims % 6 + 1;
+    size_t reshape_size = Dims / MAX_RANK_SUPPORTED + 1;
+    size_t reduce_size = Dims % MAX_RANK_SUPPORTED + 1;
     PADDLE_ENFORCE_EQ(reshape_size, reshape_dims_vec.size(),
                       "Inconsistent size between template Dims and "
                       "reshape dimensions.");
@@ -145,11 +154,11 @@ class ExpandGradKernel : public framework::OpKernel<T> {
     auto x = EigenVector<T>::Flatten(*(context.Input<Tensor>("X")));
     out0->mutable_data<T>(context.GetPlace());
     auto x_grad = EigenVector<T>::Flatten(*out0);
-    Eigen::DSizes<int, Dims / 6 + 1> reshape_dims;
+    Eigen::DSizes<int, Dims / MAX_RANK_SUPPORTED + 1> reshape_dims;
     for (size_t i = 0; i < reshape_size; ++i) {
       reshape_dims[i] = reshape_dims_vec[i];
     }
-    Eigen::DSizes<int, Dims % 6 + 1> reduce_dims;
+    Eigen::DSizes<int, Dims % MAX_RANK_SUPPORTED + 1> reduce_dims;
     for (size_t i = 0; i < reduce_size; ++i) {
       reduce_dims[i] = reduce_dims_vec[i];
     }

python/paddle/v2/framework/tests/test_expand_op.py

Lines changed: 10 additions & 10 deletions
@@ -7,7 +7,7 @@ class TestExpandOpRank1(OpTest):
     def setUp(self):
         self.op_type = "expand"
         self.inputs = {'X': np.random.random(12).astype("float32")}
-        self.attrs = {'expandTimes': [2]}
+        self.attrs = {'expand_times': [2]}
         output = np.tile(self.inputs['X'], 2)
         self.outputs = {'Out': output}
 
@@ -18,11 +18,11 @@ def test_check_grad(self):
         self.check_grad(['X'], 'Out')
 
 
-class TestExpandOpRank2_1(OpTest):
+class TestExpandOpRank2_Corner(OpTest):
     def setUp(self):
         self.op_type = "expand"
         self.inputs = {'X': np.random.random((12, 14)).astype("float32")}
-        self.attrs = {'expandTimes': [1, 1]}
+        self.attrs = {'expand_times': [1, 1]}
         output = np.tile(self.inputs['X'], (1, 1))
         self.outputs = {'Out': output}
 
@@ -33,11 +33,11 @@ def test_check_grad(self):
         self.check_grad(['X'], 'Out')
 
 
-class TestExpandOpRank2_2(OpTest):
+class TestExpandOpRank2(OpTest):
     def setUp(self):
         self.op_type = "expand"
         self.inputs = {'X': np.random.random((12, 14)).astype("float32")}
-        self.attrs = {'expandTimes': [2, 3]}
+        self.attrs = {'expand_times': [2, 3]}
         output = np.tile(self.inputs['X'], (2, 3))
         self.outputs = {'Out': output}
 
@@ -48,11 +48,11 @@ def test_check_grad(self):
         self.check_grad(['X'], 'Out')
 
 
-class TestExpandOpRank3_1(OpTest):
+class TestExpandOpRank3_Corner(OpTest):
     def setUp(self):
         self.op_type = "expand"
         self.inputs = {'X': np.random.random((2, 4, 5)).astype("float32")}
-        self.attrs = {'expandTimes': [1, 1, 1]}
+        self.attrs = {'expand_times': [1, 1, 1]}
        output = np.tile(self.inputs['X'], (1, 1, 1))
         self.outputs = {'Out': output}
 
@@ -63,11 +63,11 @@ def test_check_grad(self):
         self.check_grad(['X'], 'Out')
 
 
-class TestExpandOpRank3_2(OpTest):
+class TestExpandOpRank3(OpTest):
     def setUp(self):
         self.op_type = "expand"
         self.inputs = {'X': np.random.random((2, 4, 5)).astype("float32")}
-        self.attrs = {'expandTimes': [2, 1, 4]}
+        self.attrs = {'expand_times': [2, 1, 4]}
         output = np.tile(self.inputs['X'], (2, 1, 4))
         self.outputs = {'Out': output}
 
@@ -82,7 +82,7 @@ class TestExpandOpRank4(OpTest):
     def setUp(self):
         self.op_type = "expand"
        self.inputs = {'X': np.random.random((2, 4, 5, 7)).astype("float32")}
-        self.attrs = {'expandTimes': [3, 2, 1, 2]}
+        self.attrs = {'expand_times': [3, 2, 1, 2]}
         output = np.tile(self.inputs['X'], (3, 2, 1, 2))
         self.outputs = {'Out': output}
 