@@ -24,26 +24,28 @@ class ExpandOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(const framework::InferShapeContext& ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "X must be initialized.");
-    std::vector<int> expand_times = Attr<std::vector<int>>("expandTimes");
-    auto x_dims = ctx.Input<Tensor>("X")->dims();
-
-    PADDLE_ENFORCE_EQ(x_dims.size(), expand_times.size(),
-                      "The number of expandTimes's value must be equal "
-                      "to the rank of X.");
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must be initialized.");
+    std::vector<int> expand_times =
+        ctx->Attrs().Get<std::vector<int>>("expandTimes");
+    auto x_dims = ctx->GetInputDim("X");
+
+    PADDLE_ENFORCE_EQ(static_cast<size_t>(x_dims.size()), expand_times.size(),
+                      "The number of Attr(expandTimes)'s value must be equal "
+                      "to the rank of Input(X).");
     PADDLE_ENFORCE_LE(x_dims.size(), 6,
-                      "The rank of X must not be greater than 6.");
+                      "The rank of Input(X) must not be greater than 6.");
 
     std::vector<int64_t> out_shape(x_dims.size());
     for (size_t i = 0; i < expand_times.size(); ++i) {
       PADDLE_ENFORCE_GE(expand_times[i], 1,
-                        "Each value of expandTimes should not be "
+                        "Each value of Attr(expandTimes) should not be "
                         "less than 1.");
       out_shape[i] = x_dims[i] * expand_times[i];
     }
-    auto* out = ctx.Output<framework::LoDTensor>("Out");
-    out->Resize(framework::make_ddim(out_shape));
+
+    ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
+    ctx->ShareLoD("X", "Out");
   }
 };
 
@@ -52,20 +54,21 @@ class ExpandOpMaker : public framework::OpProtoAndCheckerMaker {
   ExpandOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
-             "The input tensor of expand op."
-             "The rank of X should be between in 1 and 6.");
+             "(Tensor, default Tensor<float>) A tensor with rank in [1, 6]."
+             "X is the input tensor to be expanded.");
     AddOutput("Out",
-              "Output tensor of expand op."
-              "The rank of Out is same as X except that each dimension size "
-              "of Out equals to corresponding dimension size of X multiplying "
-              "corresponding value of expandTimes.");
+              "(Tensor, default Tensor<float>) A tensor with rank in [1, 6]."
+              "The rank of Output(Out) is same as Input(X) except that each "
+              "dimension size of Output(Out) is equal to corresponding "
+              "dimension size of Input(X) multiplying corresponding value of "
+              "Attr(expandTimes).");
     AddAttr<std::vector<int>>("expandTimes",
                               "Expand times number for each dimension.");
     AddComment(R"DOC(
Expand operator tiles the input by given times number. You should set times
number for each dimension by providing attribute 'expandTimes'. The rank of X
-should be between in 1 and 6. Please notice that size of 'expandTimes' must be
-same with X's rank.
+should be in [1, 6]. Please notice that size of 'expandTimes' must be same with
+X's rank.
 )DOC");
   }
 };
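For reference, the DOC string above pins down the shape rule: each output dimension is the matching input dimension multiplied by the corresponding 'expandTimes' entry. The short standalone sketch below is not part of this patch; the helper name and the sample shapes are illustrative only, and it simply replays that shape arithmetic outside the framework.

// Standalone sketch of the shape rule from the DOC string; ExpandedShape and
// the sample shapes are hypothetical, not part of the patched operator.
#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors the InferShape logic: out_shape[i] = x_dims[i] * expand_times[i].
std::vector<int64_t> ExpandedShape(const std::vector<int64_t>& x_dims,
                                   const std::vector<int>& expand_times) {
  std::vector<int64_t> out_shape(x_dims.size());
  for (size_t i = 0; i < expand_times.size(); ++i) {
    out_shape[i] = x_dims[i] * expand_times[i];
  }
  return out_shape;
}

int main() {
  // X with shape [2, 3] and expandTimes = {2, 3} gives Out with shape [4, 9].
  for (int64_t d : ExpandedShape({2, 3}, {2, 3})) std::cout << d << " ";
  std::cout << "\n";
  return 0;
}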
@@ -75,25 +78,27 @@ class ExpandGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(const framework::InferShapeContext& ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "X must be initialized.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
-                            "Input(Out@GRAD) should not be null.");
-    auto x_dims = ctx.Input<Tensor>("X")->dims();
-    std::vector<int> expand_times = Attr<std::vector<int>>("expandTimes");
-    auto out_dims =
-        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->dims();
-    auto* x_grad =
-        ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "Input(Out@GRAD) should not be null.");
+    auto x_dims = ctx->GetInputDim("X");
+    std::vector<int> expand_times =
+        ctx->Attrs().Get<std::vector<int>>("expandTimes");
+    auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
 
     for (size_t i = 0; i < expand_times.size(); ++i) {
       PADDLE_ENFORCE_EQ(x_dims[i] * expand_times[i], out_dims[i],
                         "Each dimension size of Input(Out@GRAD) should be "
                         "equal to multiplication of corresponding dimension "
-                        "size of Input(X) and expandTimes value.");
+                        "size of Input(X) and Attr(expandTimes) value.");
     }
 
-    if (x_grad) x_grad->Resize(x_dims);
+    auto x_grad_name = framework::GradVarName("X");
+
+    if (ctx->HasOutput(x_grad_name)) {
+      ctx->SetOutputDim(x_grad_name, x_dims);
+    }
   }
 };
 