Skip to content

Commit 8d4e2d4

Browse files
committed
1. Add unit test for empty sequence case
2. Fix comments and paddle enforce check
1 parent 9f32b61 commit 8d4e2d4

File tree

2 files changed

+38
-11
lines changed

2 files changed

+38
-11
lines changed

paddle/operators/seq_expand_op.cc

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,8 @@ class SeqExpandOp : public framework::OperatorWithKernel {
2525

2626
protected:
2727
void InferShape(framework::InferShapeContext* ctx) const override {
28-
PADDLE_ENFORCE(ctx->HasInput("X"),
29-
"Input(X) of SeqExpandOp should not be null.");
30-
PADDLE_ENFORCE(ctx->HasOutput("Out"),
31-
"Output(Out) of SeqExpandOp should not be null.");
28+
PADDLE_ENFORCE(ctx->HasInput("X"));
29+
PADDLE_ENFORCE(ctx->HasOutput("Out"));
3230
PADDLE_ENFORCE(
3331
ctx->HasInput("Y"),
3432
"Input(Y) of SeqExpandOp should not be null while repeat == 0.");
@@ -54,7 +52,7 @@ class SeqExpandOpMaker : public framework::OpProtoAndCheckerMaker {
5452
"The element numbers of last level in input('Y') "
5553
"must be equal to dims[0] of input('X').");
5654
AddOutput("Out",
57-
"The output of seq_expand op."
55+
"(LodTensor)The output of seq_expand op."
5856
"The lod of output will be as same as input(Y)'s lod.");
5957
AddComment(R"DOC(
6058
Expand input(X) according to LOD of input(Y).
@@ -69,6 +67,7 @@ Given 2-level a LoDTensor input(X)
6967
and input(Y)
7068
Y.lod = [[0, 2, 4],
7169
[0, 3, 6, 7, 8]]
70+
with condition len(Y.lod[-1]) - 1 == X.dims[0]
7271
then we get 2-level LoDTensor
7372
Out.lod = [[0, 2, 4],
7473
[0, 3, 6, 7, 8]]
@@ -83,6 +82,7 @@ Given a 0-level LoDTensor input(X)
8382
X.dims = [3, 1]
8483
and input(Y)
8584
Y.lod = [[0, 2, 3, 6]]
85+
with condition len(Y.lod[-1]) - 1 == X.dims[0]
8686
then we get 1-level LoDTensor
8787
Out.lod = [[0, 2, 3, 6]]
8888
Out.data = [a, a, b, c, c, c]
@@ -96,11 +96,29 @@ Given a 0-level LoDTensor input(X)
9696
X.dims = [3, 2]
9797
and input(Y)
9898
Y.lod = [[0, 2, 3, 6]]
99+
with condition len(Y.lod[-1]) - 1 == X.dims[0]
99100
then we get 1-level LoDTensor
100101
Out.lod = [[0, 2, 3, 6]]
101102
Out.data = [[a,b], [a,b] [c,d], [e, f], [e, f], [e, f]]
102103
Out.dims = [6, 2]
103104
105+
Case 4:
106+
107+
Given a 2-level LoDTensor input(X)
108+
X.lod = [[0, 2, 3],
109+
[0, 1, 3, 4]]
110+
X.data = [a, b, c, d]
111+
X.dims = [4, 1]
112+
and input(Y)
113+
Y.lod = [[0, 2, 4],
114+
[0, 3, 6, 6, 8]]
115+
with condition len(Y.lod[-1]) - 1 == X.dims[0]
116+
then we get 2-level LoDTensor
117+
Out.lod = [[0, 2, 4],
118+
[0, 3, 6, 6, 8]]
119+
Out.data = [a, a, a, b, b, b, d, d]
120+
Out.dims = [8, 1]
121+
104122
105123
)DOC");
106124
}
@@ -112,8 +130,8 @@ class SeqExpandOpGrad : public framework::OperatorWithKernel {
112130

113131
protected:
114132
void InferShape(framework::InferShapeContext* ctx) const override {
115-
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
116-
PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) should not be null");
133+
PADDLE_ENFORCE(ctx->HasInput("X"));
134+
PADDLE_ENFORCE(ctx->HasInput("Out"));
117135
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
118136
"Input(Out@GRAD) should not be null");
119137
auto x_dims = ctx->GetInputDim("X");

paddle/operators/seq_expand_op.h

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ class SeqExpandKernel : public framework::OpKernel<T> {
3636
"The size of last lod level in Input(Y)"
3737
"must be equal to dims[0] of Input(X).");
3838
out->set_lod(y->lod());
39-
out->Resize(y->dims());
4039
auto place = context.GetEigenDevice<Place>();
4140
size_t element_len = framework::product(x_dims) / x_dims[0];
4241
T* out_data = out->mutable_data<T>(context.GetPlace());
@@ -57,6 +56,18 @@ class SeqExpandKernel : public framework::OpKernel<T> {
5756
}
5857
};
5958

59+
/*
60+
*Given Grad(Out)
61+
*
62+
* Grad(Out).lod = [[0, 2],
63+
* [0, 3, 6]]
64+
* Grad(Out).data = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
65+
* Then
66+
* Grad(X).data = [(0.1 + 0.2 + 0.3), (0.4 + 0.5 + 0.6)]
67+
* = [0.6, 1.5]
68+
* Grad(X).lod = Input(X).lod
69+
*
70+
* */
6071
template <typename Place, typename T>
6172
class SeqExpandGradKernel : public framework::OpKernel<T> {
6273
public:
@@ -68,10 +79,8 @@ class SeqExpandGradKernel : public framework::OpKernel<T> {
6879
auto out_last_level = out->lod().back();
6980
d_x->set_lod(x->lod());
7081
const T* d_out_data = d_out->data<T>();
71-
auto d_out_dims = d_out->dims();
7282
T* d_x_data = d_x->mutable_data<T>(context.GetPlace());
73-
size_t element_len = framework::product(d_out_dims) / d_out_dims[0];
74-
83+
size_t element_len = d_out->numel() / d_out->dims()[0];
7584
for (size_t i = 0; i < out_last_level.size() - 1; ++i) {
7685
size_t repeat = out_last_level[i + 1] - out_last_level[i];
7786
Eigen::TensorMap<

0 commit comments

Comments
 (0)