@@ -22,17 +22,16 @@ class LoDResetOp : public framework::OperatorWithKernel {
22
22
using framework::OperatorWithKernel::OperatorWithKernel;
23
23
24
24
void InferShape (framework::InferShapeContext *ctx) const override {
25
- // input check
26
25
PADDLE_ENFORCE (ctx->HasInput (" X" ),
27
26
" Input(X) of LoDResetOp should not be null." );
28
27
PADDLE_ENFORCE (ctx->HasOutput (" Out" ),
29
28
" Output(Out) of LoDResetOp should not be null." );
30
- // If target LoD is not set form Input(), then it must be set from Attr().
31
- if (!ctx->HasInput (" TargetLoD " )) {
29
+
30
+ if (!ctx->HasInput (" Y " )) {
32
31
auto level0 = ctx->Attrs ().Get <std::vector<int >>(" target_lod" );
33
- PADDLE_ENFORCE (level0.size () > 1 ,
34
- " Target LoD is not found, should be set to be a valid one "
35
- " through Input() or Attr() ." );
32
+ PADDLE_ENFORCE_GT (level0.size (), 1 ,
33
+ " If Input(Y) not provided, the target lod should be "
34
+ " specified by attribute `target_lod` ." );
36
35
}
37
36
ctx->SetOutputDim (" Out" , ctx->GetInputDim (" X" ));
38
37
}
@@ -50,36 +49,77 @@ class LoDResetOpMaker : public framework::OpProtoAndCheckerMaker {
50
49
public:
51
50
LoDResetOpMaker (OpProto *proto, OpAttrChecker *op_checker)
52
51
: OpProtoAndCheckerMaker(proto, op_checker) {
53
- AddInput (" X" , " (LoDTensor) The input tensor of lod_reset operator." );
54
- AddInput (" TargetLoD" ,
55
- " (Tensor, optional) The target level 0 LoD from Input()." )
52
+ AddInput (" X" ,
53
+ " (Tensor, LoDTensor) Input variable of LoDResetOp which "
54
+ " could be a Tensor or LoDTensor, where the data of output "
55
+ " variable inherits from." );
56
+ AddInput (" Y" ,
57
+ " (Tensor, LoDTensor, optional) If provided and Y is LoDTensor, "
58
+ " lod of Input(Y) would be considered as the target lod first, "
59
+ " otherwise data of Input(Y) would be considered as the "
60
+ " target lod." )
56
61
.AsDispensable ();
57
- AddOutput (" Out" , " (LoDTensor) The output tensor of lod_reset operator." );
62
+ AddOutput (" Out" ,
63
+ " (LoDTensor) Output variable of LoDResetOp which should be a "
64
+ " LoDTensor." );
58
65
AddAttr<std::vector<int >>(" target_lod" ,
59
66
" The target level 0 LoD from Attr()." )
60
67
.SetDefault (std::vector<int >{});
61
68
AddComment (R"DOC( LoDReset operator
62
69
63
- Reset LoD of Input(X) into a new one specified by Input(TargetLoD) or
64
- Attr(target_lod), or set LoD for Input(X) if it doesn't have one.
65
- Currently the lod_reset operator only supports the reset of level 0 LoD.
66
- At least one of Input(TargetLoD) and Attr(target_lod) must be set,
67
- and if both of them are set, Input(TargetLoD) will be chosen as the
68
- target LoD.
70
+ Set LoD of `X` to a new one specified by `Y` or attribute `target_lod`. When `Y`
71
+ provided and `Y` is a LoDTensor, `Y.lod` would be considered as target LoD
72
+ first, otherwise `Y.data` would be considered as target LoD. If `Y` is not
73
+ provided, target LoD should be specified by attribute `target_lod`.
74
+ If target LoD is specified by `Y.data` or `target_lod`, only one level LoD
75
+ is supported.
76
+
77
+ Example 1:
78
+
79
+ Given a 1-level LoDTensor input(X):
80
+ X.lod = [[ 0, 2, 5, 6 ]]
81
+ X.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
82
+ X.dims = [6, 1]
83
+
84
+ attr(target_lod): [0, 4, 6]
85
+
86
+ then we get a 1-level LoDTensor:
87
+ Out.lod = [[ 0, 4, 6 ]]
88
+ Out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
89
+ Out.dims = [6, 1]
90
+
91
+ Example 2:
69
92
70
- An example:
71
- Given a float LoDTensor X with shape (6, 1), its transpose form represents
93
+ Given a 1-level LoDTensor input(X):
94
+ X.lod = [[ 0, 2, 5, 6 ]]
95
+ X.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
96
+ X.dims = [6, 1]
72
97
73
- [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
98
+ input(Y) is a Tensor:
99
+ Y.data = [[0, 2, 6]]
100
+ Y.dims = [1, 3]
74
101
75
- with LoD = [[0, 2, 5, 6]] and the three (transposed) sequences look like
102
+ then we get a 1-level LoDTensor:
103
+ Out.lod = [[ 0, 2, 6 ]]
104
+ Out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
105
+ Out.dims = [6, 1]
76
106
77
- [1.0, 2.0], [3.0, 4.0, 5.0], [6.0].
107
+ Example 3:
78
108
79
- If target LoD = [0, 4, 6], the lod_reset operator will reset the LoD and
80
- the sequences that the LoDTensor Output(Out) contains becomes:
109
+ Given a 1-level LoDTensor input(X):
110
+ X.lod = [[ 0, 2, 5, 6 ]]
111
+ X.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
112
+ X.dims = [6, 1]
81
113
82
- [1.0, 2.0, 3.0, 4.0], [5.0, 6.0].
114
+ input(Y) is a 2-level LoDTensor:
115
+ Y.lod = [[0, 2, 4], [0, 2, 5, 6]]
116
+ Y.data = [[1.1], [2.1], [3.1], [4.1], [5.1], [6.1]]
117
+ Y.dims = [6, 1]
118
+
119
+ then we get a 2-level LoDTensor:
120
+ Out.lod = [[0, 2, 4], [0, 2, 5, 6]]
121
+ Out.data = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
122
+ Out.dims = [6, 1]
83
123
84
124
)DOC" );
85
125
}
@@ -90,10 +130,16 @@ class LoDResetGradOp : public framework::OperatorWithKernel {
90
130
using framework::OperatorWithKernel::OperatorWithKernel;
91
131
92
132
void InferShape (framework::InferShapeContext *ctx) const override {
93
- PADDLE_ENFORCE (ctx->HasInput (" X" ), " Input(X) shouldn't be null." );
133
+ PADDLE_ENFORCE (ctx->HasInput (" X" ),
134
+ " Input(X) of LoDResetGradOp should not be null." );
94
135
PADDLE_ENFORCE (ctx->HasInput (framework::GradVarName (" Out" )),
95
- " Input(Out@GRAD) shouldn't be null." );
96
- ctx->SetOutputDim (framework::GradVarName (" X" ), ctx->GetInputDim (" X" ));
136
+ " Input(Out@Grad) of LoDResetGradOp should not be null." );
137
+
138
+ auto x_grad_name = framework::GradVarName (" X" );
139
+ if (ctx->HasOutput (x_grad_name)) {
140
+ ctx->SetOutputDim (x_grad_name, ctx->GetInputDim (" X" ));
141
+ ctx->ShareLoD (" X" , /* ->*/ x_grad_name);
142
+ }
97
143
}
98
144
99
145
protected:
@@ -111,9 +157,13 @@ class LoDResetGradOp : public framework::OperatorWithKernel {
111
157
namespace ops = paddle::operators;
112
158
REGISTER_OP (lod_reset, ops::LoDResetOp, ops::LoDResetOpMaker, lod_reset_grad,
113
159
ops::LoDResetGradOp);
114
- REGISTER_OP_CPU_KERNEL (lod_reset,
115
- ops::LoDResetKernel<paddle::platform::CPUPlace, float >,
116
- ops::LoDResetKernel<paddle::platform::CPUPlace, double >);
160
+ REGISTER_OP_CPU_KERNEL (
161
+ lod_reset, ops::LoDResetKernel<paddle::platform::CPUPlace, float >,
162
+ ops::LoDResetKernel<paddle::platform::CPUPlace, double >,
163
+ ops::LoDResetKernel<paddle::platform::CPUPlace, int >,
164
+ ops::LoDResetKernel<paddle::platform::CPUPlace, int64_t >);
117
165
REGISTER_OP_CPU_KERNEL (
118
166
lod_reset_grad, ops::LoDResetGradKernel<paddle::platform::CPUPlace, float >,
119
- ops::LoDResetGradKernel<paddle::platform::CPUPlace, double >);
167
+ ops::LoDResetGradKernel<paddle::platform::CPUPlace, double >,
168
+ ops::LoDResetGradKernel<paddle::platform::CPUPlace, int >,
169
+ ops::LoDResetGradKernel<paddle::platform::CPUPlace, int64_t >);
0 commit comments