Skip to content

Commit 0d9ba3d

Browse files
committed
Adapt to new interface.
1 parent 7be390a commit 0d9ba3d

File tree

2 files changed

+55
-56
lines changed

2 files changed

+55
-56
lines changed

paddle/operators/expand_op.cc

Lines changed: 37 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -24,26 +24,28 @@ class ExpandOp : public framework::OperatorWithKernel {
2424
using framework::OperatorWithKernel::OperatorWithKernel;
2525

2626
protected:
27-
void InferShape(const framework::InferShapeContext& ctx) const override {
28-
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "X must be initialized.");
29-
std::vector<int> expand_times = Attr<std::vector<int>>("expandTimes");
30-
auto x_dims = ctx.Input<Tensor>("X")->dims();
31-
32-
PADDLE_ENFORCE_EQ(x_dims.size(), expand_times.size(),
33-
"The number of expandTimes's value must be equal "
34-
"to the rank of X.");
27+
void InferShape(framework::InferShapeContext* ctx) const override {
28+
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must be initialized.");
29+
std::vector<int> expand_times =
30+
ctx->Attrs().Get<std::vector<int>>("expandTimes");
31+
auto x_dims = ctx->GetInputDim("X");
32+
33+
PADDLE_ENFORCE_EQ(static_cast<size_t>(x_dims.size()), expand_times.size(),
34+
"The number of Attr(expandTimes)'s value must be equal "
35+
"to the rank of Input(X).");
3536
PADDLE_ENFORCE_LE(x_dims.size(), 6,
36-
"The rank of X must not be greater than 6.");
37+
"The rank of Input(X) must not be greater than 6.");
3738

3839
std::vector<int64_t> out_shape(x_dims.size());
3940
for (size_t i = 0; i < expand_times.size(); ++i) {
4041
PADDLE_ENFORCE_GE(expand_times[i], 1,
41-
"Each value of expandTimes should not be "
42+
"Each value of Attr(expandTimes) should not be "
4243
"less than 1.");
4344
out_shape[i] = x_dims[i] * expand_times[i];
4445
}
45-
auto* out = ctx.Output<framework::LoDTensor>("Out");
46-
out->Resize(framework::make_ddim(out_shape));
46+
47+
ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
48+
ctx->ShareLoD("X", "Out");
4749
}
4850
};
4951

@@ -52,20 +54,21 @@ class ExpandOpMaker : public framework::OpProtoAndCheckerMaker {
5254
ExpandOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
5355
: OpProtoAndCheckerMaker(proto, op_checker) {
5456
AddInput("X",
55-
"The input tensor of expand op."
56-
"The rank of X should be between in 1 and 6.");
57+
"(Tensor, default Tensor<float>) A tensor with rank in [1, 6]."
58+
"X is the input tensor to be expanded.");
5759
AddOutput("Out",
58-
"Output tensor of expand op."
59-
"The rank of Out is same as X except that each dimension size "
60-
"of Out equals to corresponding dimension size of X multiplying "
61-
"corresponding value of expandTimes.");
60+
"(Tensor, default Tensor<float>) A tensor with rank in [1, 6]."
61+
"The rank of Output(Out) is the same as Input(X) except that each "
62+
"dimension size of Output(Out) is equal to corresponding "
63+
"dimension size of Input(X) multiplying corresponding value of "
64+
"Attr(expandTimes).");
6265
AddAttr<std::vector<int>>("expandTimes",
6366
"Expand times number for each dimension.");
6467
AddComment(R"DOC(
6568
Expand operator tiles the input by given times number. You should set times
6669
number for each dimension by providing attribute 'expandTimes'. The rank of X
67-
should be between in 1 and 6. Please notice that size of 'expandTimes' must be
68-
same with X's rank.
70+
should be in [1, 6]. Please note that the size of 'expandTimes' must be the same as
71+
X's rank.
6972
)DOC");
7073
}
7174
};
@@ -75,25 +78,27 @@ class ExpandGradOp : public framework::OperatorWithKernel {
7578
using framework::OperatorWithKernel::OperatorWithKernel;
7679

7780
protected:
78-
void InferShape(const framework::InferShapeContext& ctx) const override {
79-
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "X must be initialized.");
80-
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
81-
"Input(Out@GRAD) should not be null.");
82-
auto x_dims = ctx.Input<Tensor>("X")->dims();
83-
std::vector<int> expand_times = Attr<std::vector<int>>("expandTimes");
84-
auto out_dims =
85-
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->dims();
86-
auto* x_grad =
87-
ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
81+
void InferShape(framework::InferShapeContext* ctx) const override {
82+
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
83+
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
84+
"Input(Out@GRAD) should not be null.");
85+
auto x_dims = ctx->GetInputDim("X");
86+
std::vector<int> expand_times =
87+
ctx->Attrs().Get<std::vector<int>>("expandTimes");
88+
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
8889

8990
for (size_t i = 0; i < expand_times.size(); ++i) {
9091
PADDLE_ENFORCE_EQ(x_dims[i] * expand_times[i], out_dims[i],
9192
"Each dimension size of Input(Out@GRAD) should be "
9293
"equal to multiplication of corresponding dimension "
93-
"size of Input(X) and expandTimes value.");
94+
"size of Input(X) and Attr(expandTimes) value.");
9495
}
9596

96-
if (x_grad) x_grad->Resize(x_dims);
97+
auto x_grad_name = framework::GradVarName("X");
98+
99+
if (ctx->HasOutput(x_grad_name)) {
100+
ctx->SetOutputDim(x_grad_name, x_dims);
101+
}
97102
}
98103
};
99104

paddle/operators/expand_op.h

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@
4545
namespace paddle {
4646
namespace operators {
4747

48+
using Tensor = framework::Tensor;
49+
4850
template <typename T, int MajorType = Eigen::RowMajor,
4951
typename IndexType = Eigen::DenseIndex>
5052
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
@@ -53,24 +55,24 @@ template <typename T, size_t D, int MajorType = Eigen::RowMajor,
5355
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
5456

5557
template <typename Place, typename T>
56-
class ExpandKernel : public framework::OpKernel {
58+
class ExpandKernel : public framework::OpKernel<T> {
5759
public:
5860
void Compute(const framework::ExecutionContext& context) const override {
59-
auto rank = context.Input<framework::Tensor>("X")->dims().size();
61+
auto rank = context.Input<Tensor>("X")->dims().size();
6062
switch (rank) {
6163
REP_EXPAND_TEMPLATE(6)
6264
default:
6365
PADDLE_ENFORCE(false,
6466
"Only support tensor with rank being between 1 and 6.");
65-
};
67+
}
6668
}
6769

6870
protected:
6971
template <int Rank>
7072
void Expand(const framework::ExecutionContext& context) const {
71-
auto* in0 = context.Input<framework::Tensor>("X");
73+
auto* in0 = context.Input<Tensor>("X");
7274
auto& expand_times = context.Attr<std::vector<int>>("expandTimes");
73-
auto* out0 = context.Output<framework::LoDTensor>("Out");
75+
auto* out0 = context.Output<Tensor>("Out");
7476
Eigen::DSizes<int, Rank> bcast_dims;
7577
auto x_dims = in0->dims();
7678
for (size_t i = 0; i < expand_times.size(); ++i) {
@@ -85,10 +87,10 @@ class ExpandKernel : public framework::OpKernel {
8587
};
8688

8789
template <typename Place, typename T>
88-
class ExpandGradKernel : public framework::OpKernel {
90+
class ExpandGradKernel : public framework::OpKernel<T> {
8991
public:
9092
void Compute(const framework::ExecutionContext& context) const override {
91-
auto* in0 = context.Input<framework::Tensor>("X");
93+
auto* in0 = context.Input<Tensor>("X");
9294
auto& expand_times = context.Attr<std::vector<int>>("expandTimes");
9395
auto x_dims = in0->dims();
9496
std::vector<int> reshape_dims_vec;
@@ -111,23 +113,17 @@ class ExpandGradKernel : public framework::OpKernel {
111113
int dims = reshape_dims_vec.size() * 6 + reduce_dims_vec.size() - 7;
112114
// no need reduce, just copy
113115
if (reduce_dims_vec.size() == 0) {
114-
auto* in0 =
115-
context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
116-
auto* out0 =
117-
context.Output<framework::LoDTensor>(framework::GradVarName("X"));
116+
auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
117+
auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
118118
out0->mutable_data<T>(context.GetPlace());
119-
if (platform::is_cpu_place(context.GetPlace())) {
120-
out0->CopyFrom<T>(*in0, platform::CPUPlace());
121-
} else {
122-
out0->CopyFrom<T>(*in0, platform::GPUPlace());
123-
}
119+
out0->CopyFrom(*in0, context.GetPlace(), context.device_context());
124120
} else {
125121
switch (dims) {
126122
REP_EXPAND_GRAD_TEMPLATE(72)
127123
default:
128124
PADDLE_ENFORCE(
129125
false, "Only support tensor with rank being between 1 and 6.");
130-
};
126+
}
131127
}
132128
}
133129

@@ -144,11 +140,9 @@ class ExpandGradKernel : public framework::OpKernel {
144140
PADDLE_ENFORCE_EQ(reduce_size, reduce_dims_vec.size(),
145141
"Inconsistent size between template Dims and "
146142
"reduce dimensions.");
147-
auto* in0 =
148-
context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
149-
auto* out0 =
150-
context.Output<framework::LoDTensor>(framework::GradVarName("X"));
151-
auto x = EigenVector<T>::Flatten(*(context.Input<framework::Tensor>("X")));
143+
auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
144+
auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
145+
auto x = EigenVector<T>::Flatten(*(context.Input<Tensor>("X")));
152146
out0->mutable_data<T>(context.GetPlace());
153147
auto x_grad = EigenVector<T>::Flatten(*out0);
154148
Eigen::DSizes<int, Dims / 6 + 1> reshape_dims;
@@ -165,5 +159,5 @@ class ExpandGradKernel : public framework::OpKernel {
165159
}
166160
};
167161

168-
} // operators
169-
} // paddle
162+
} // namespace operators
163+
} // namespace paddle

0 commit comments

Comments
 (0)