@@ -29,10 +29,12 @@ class SequencePadOp : public framework::OperatorWithKernel {
29
29
" Input(PadValue) of SequencePadOp should not be null." );
30
30
PADDLE_ENFORCE (ctx->HasOutput (" Out" ),
31
31
" Output(Out) of SequencePadOp should not be null." );
32
+ PADDLE_ENFORCE (ctx->HasOutput (" Length" ),
33
+ " Output(Length) of SequencePadOp should not be null." );
32
34
33
35
auto x_dims = ctx->GetInputDim (" X" );
34
36
PADDLE_ENFORCE_GE (x_dims.size (), 2 ,
35
- " The rank of Input(x ) can't be less than 2." );
37
+ " The rank of Input(X ) can't be less than 2." );
36
38
auto time_step_dims = framework::slice_ddim (x_dims, 1 , x_dims.size ());
37
39
auto pad_value_dims = ctx->GetInputDim (" PadValue" );
38
40
PADDLE_ENFORCE (pad_value_dims == framework::make_ddim ({1 }) ||
@@ -41,8 +43,8 @@ class SequencePadOp : public framework::OperatorWithKernel {
41
43
" shape equals to time steps in sequences" );
42
44
43
45
int out_dim_0 = -1 ;
44
- int out_dim_1 = -1 ;
45
46
47
+ int padded_length = ctx->Attrs ().Get <int >(" padded_length" );
46
48
if (ctx->IsRuntime ()) {
47
49
// run time
48
50
framework::Variable* x_var =
@@ -58,27 +60,37 @@ class SequencePadOp : public framework::OperatorWithKernel {
58
60
59
61
int seq_num = x_lod_0.size () - 1 ;
60
62
int max_seq_len = math::MaximumSequenceLength (x_lod_0);
61
- int padded_length = ctx->Attrs ().Get <int >(" padded_length" );
62
63
if (padded_length == -1 ) {
63
64
padded_length = max_seq_len;
64
65
}
65
66
PADDLE_ENFORCE_GE (padded_length, max_seq_len,
66
67
" The Attr(padded_length) must be -1 or an int greater "
67
68
" than the length of the longest original sequence." );
68
69
out_dim_0 = seq_num;
69
- out_dim_1 = padded_length;
70
70
} else {
71
71
// compile time
72
+ if (padded_length == -1 ) {
73
+ padded_length = 1 ;
74
+ }
72
75
framework::VarDesc* x_desc =
73
76
boost::get<framework::VarDesc*>(ctx->GetInputVarPtrs (" X" )[0 ]);
74
77
PADDLE_ENFORCE_GE (x_desc->GetLoDLevel (), 1 );
75
78
}
76
79
77
- std::vector<int > out_dims_vec{out_dim_0, out_dim_1};
80
+ std::vector<int > out_dims_vec{out_dim_0, padded_length};
81
+ std::vector<int > len_dims_vec{out_dim_0, 1 };
78
82
auto time_step_dims_vec = framework::vectorize2int (time_step_dims);
79
83
out_dims_vec.insert (out_dims_vec.end (), time_step_dims_vec.begin (),
80
84
time_step_dims_vec.end ());
81
85
ctx->SetOutputDim (" Out" , framework::make_ddim (out_dims_vec));
86
+ ctx->SetOutputDim (" Length" , framework::make_ddim (len_dims_vec));
87
+ }
88
+
89
+ protected:
90
+ framework::OpKernelType GetExpectedKernelType (
91
+ const framework::ExecutionContext& ctx) const override {
92
+ auto data_type = framework::GetDataTypeOfVar (ctx.InputVar (" X" ));
93
+ return framework::OpKernelType (data_type, ctx.device_context ());
82
94
}
83
95
};
84
96
@@ -96,6 +108,10 @@ class SequencePadOpMaker : public framework::OpProtoAndCheckerMaker {
96
108
AddOutput (
97
109
" Out" ,
98
110
" (LoDTensor) The output vairable, which contains padded sequences." );
111
+ AddOutput (
112
+ " Length" ,
113
+ " (LoDTensor) The output vairable, which contains the actual length of "
114
+ " sequences before padding." );
99
115
AddAttr<int >(
100
116
" padded_length" ,
101
117
" The length of padded sequences. It can be setted to -1 or "
@@ -125,6 +141,7 @@ class SequencePadOpMaker : public framework::OpProtoAndCheckerMaker {
125
141
then we get LoDTensor:
126
142
Out.data = [[a, b, 0, 0],
127
143
[c, d, e, 0]]
144
+ Length.data = [[2], [3]]
128
145
129
146
Case 2:
130
147
@@ -138,7 +155,8 @@ class SequencePadOpMaker : public framework::OpProtoAndCheckerMaker {
138
155
then we get LoDTensor:
139
156
Out.data = [[[a1, a2], [b1, b2], [0, 0]],
140
157
[[c1, c2], [d1, d2], [e1, e2]]]
141
-
158
+ Length.data = [[2], [3]]
159
+
142
160
Case 3:
143
161
144
162
Given a 1-level LoDTensor input(X):
@@ -151,6 +169,7 @@ class SequencePadOpMaker : public framework::OpProtoAndCheckerMaker {
151
169
then we get LoDTensor:
152
170
Out.data = [[[a1, a2], [b1, b2], [p1, p2]],
153
171
[[c1, c2], [d1, d2], [e1, e2]]]
172
+ Length.data = [[2], [3]]
154
173
155
174
)DOC" );
156
175
}
@@ -171,6 +190,13 @@ class SequencePadGradOp : public framework::OperatorWithKernel {
171
190
ctx->ShareLoD (" X" , /* ->*/ framework::GradVarName (" X" ));
172
191
}
173
192
}
193
+
194
+ protected:
195
+ framework::OpKernelType GetExpectedKernelType (
196
+ const framework::ExecutionContext& ctx) const override {
197
+ auto data_type = framework::GetDataTypeOfVar (ctx.InputVar (" X" ));
198
+ return framework::OpKernelType (data_type, ctx.device_context ());
199
+ }
174
200
};
175
201
176
202
} // namespace operators
0 commit comments