@@ -21,82 +21,85 @@ class SequencePadOp : public framework::OperatorWithKernel {
21
21
public:
22
22
using framework::OperatorWithKernel::OperatorWithKernel;
23
23
24
+ protected:
24
25
void InferShape (framework::InferShapeContext* ctx) const override {
25
26
PADDLE_ENFORCE (ctx->HasInput (" X" ),
26
27
" Input(X) of SequencePadOp should not be null." );
28
+ PADDLE_ENFORCE (ctx->HasInput (" PadValue" ),
29
+ " Input(PadValue) of SequencePadOp should not be null." );
27
30
PADDLE_ENFORCE (ctx->HasOutput (" Out" ),
28
31
" Output(Out) of SequencePadOp should not be null." );
29
32
30
33
auto x_dims = ctx->GetInputDim (" X" );
34
+ PADDLE_ENFORCE_GE (x_dims.size (), 2 ,
35
+ " The rank of Input(x) can't be less than 2." );
36
+ auto time_step_dims = framework::slice_ddim (x_dims, 1 , x_dims.size ());
37
+ auto pad_value_dims = ctx->GetInputDim (" PadValue" );
38
+ PADDLE_ENFORCE (pad_value_dims == framework::make_ddim ({1 }) ||
39
+ pad_value_dims == time_step_dims,
40
+ " The Input(PadValue) must be a scalar or a tensor whose "
41
+ " shape equals to time steps in sequences" );
31
42
32
- PADDLE_ENFORCE_EQ (x_dims.size (), 2 ,
33
- " Only support 2-D tensor, rank of Input(X) should be 2." );
34
-
35
- int lod_level = ctx->Attrs ().Get <int >(" lod_level" );
36
-
37
- int64_t max_len = -1 ;
38
- int64_t seq_num = -1 ;
39
- int x_lod_size = -1 ;
43
+ int batch_dim_size = -1 ;
40
44
41
45
if (ctx->IsRuntime ()) {
46
+ // run time
42
47
framework::Variable* x_var =
43
48
boost::get<framework::Variable*>(ctx->GetInputVarPtrs (" X" )[0 ]);
44
-
45
- auto & x_lod = x_var->Get <LoDTensor>().lod ();
46
-
47
- x_lod_size = x_lod.size ();
48
-
49
- auto x_abs_offset = framework::ToAbsOffset (x_lod)[lod_level];
50
-
51
- PADDLE_ENFORCE_EQ (x_dims[0 ], static_cast <int64_t >(x_abs_offset.back ()),
52
- " The first dimension of `X` should be equal to sum "
53
- " of all sequences' length." );
54
-
55
- seq_num = x_abs_offset.size () - 1 ;
56
-
57
- for (int64_t i = 1 ; i <= seq_num; ++i) {
58
- int64_t seq_len = x_abs_offset[i] - x_abs_offset[i - 1 ];
59
- max_len = max_len < seq_len ? seq_len : max_len;
49
+ const auto & x_lod = x_var->Get <LoDTensor>().lod ();
50
+ PADDLE_ENFORCE (!x_lod.empty (), " The Input(X) must hold lod info." );
51
+ const auto & x_lod_0 = x_lod[0 ];
52
+ PADDLE_ENFORCE_GE (x_lod_0.size (), 2 ,
53
+ " The Input(X)'s lod info is corrupted." );
54
+ PADDLE_ENFORCE_EQ (
55
+ x_dims[0 ], static_cast <int64_t >(x_lod_0.back ()),
56
+ " The Input(X)'s lod info mismatches the actual tensor shape." );
57
+
58
+ int seq_num = x_lod_0.size () - 1 ;
59
+ int max_seq_len = math::MaximumSequenceLength (x_lod_0);
60
+ int padded_length = ctx->Attrs ().Get <int >(" padded_length" );
61
+ if (padded_length == -1 ) {
62
+ padded_length = max_seq_len;
60
63
}
64
+ PADDLE_ENFORCE_GE (padded_length, max_seq_len,
65
+ " The Attr(padded_length) must be -1 or an int greater "
66
+ " than the length of the longest original sequence." );
67
+ batch_dim_size = padded_length * seq_num;
61
68
} else {
69
+ // compile time
62
70
framework::VarDesc* x_desc =
63
71
boost::get<framework::VarDesc*>(ctx->GetInputVarPtrs (" X" )[0 ]);
64
- x_lod_size = x_desc->GetLoDLevel ();
72
+ PADDLE_ENFORCE_GE ( x_desc->GetLoDLevel (), 1 );
65
73
}
66
74
67
- PADDLE_ENFORCE (lod_level >= 0 && lod_level < x_lod_size,
68
- " Invalid `lod_level` which should be at least 0 and less "
69
- " than maximum lod level of `X`" );
70
-
71
- ctx->SetOutputDim (" Out" , {seq_num, max_len, x_dims[1 ]});
72
- }
73
-
74
- protected:
75
- framework::OpKernelType GetExpectedKernelType (
76
- const framework::ExecutionContext& ctx) const override {
77
- return framework::OpKernelType (
78
- framework::ToDataType (ctx.Input <framework::LoDTensor>(" X" )->type ()),
79
- ctx.device_context ());
75
+ auto out_dims = x_dims;
76
+ out_dims[0 ] = batch_dim_size;
77
+ ctx->SetOutputDim (" Out" , out_dims);
80
78
}
81
79
};
82
80
83
81
class SequencePadOpMaker : public framework ::OpProtoAndCheckerMaker {
84
82
public:
85
- SequencePadOpMaker (OpProto* proto, OpAttrChecker* op_checker)
86
- : OpProtoAndCheckerMaker(proto, op_checker) {
83
+ void Make () override {
87
84
AddInput (" X" ,
88
85
" (LoDTensor, default LoDTensor<float>) Input variable which "
89
- " should contain lod information. Length of each sequence would "
90
- " be computed from the most bottom level lod." );
91
- AddOutput (" Out" ,
92
- " (Tensor) Output variable which would be a common tensor "
93
- " without lod. Each sequence would be padded to the maximum "
94
- " length." );
95
- AddAttr<float >(" lod_level" ,
96
- " (int, default 0) Specify which level lod to referred to." );
97
- AddAttr<float >(" pad_value" ,
98
- " (float, default 0.0) Specify which value to be padded to "
99
- " the end of each sequence." );
86
+ " should contain lod information." );
87
+ AddInput (" PadValue" ,
88
+ " (LoDTensor), this Tensor holds values that will be fill into "
89
+ " padded steps. It can be a scalar or a tensor whose shape equals "
90
+ " to time steps in sequences. If it's a scalar, it will be "
91
+ " automatically broadcasted to the shape of time step." );
92
+ AddOutput (
93
+ " Out" ,
94
+ " (LoDTensor) The output vairable, which contains padded sequences." );
95
+ AddAttr<int >(
96
+ " padded_length" ,
97
+ " The length of padded sequences. It can be setted to -1 or "
98
+ " any positive int. When it is -1, all sequences will be padded up to "
99
+ " the length of the longest one among them; when it a certain positive "
100
+ " value, it must be greater than the length of the longest original "
101
+ " sequence." )
102
+ .SetDefault (-1 );
100
103
AddComment (R"DOC(
101
104
102
105
)DOC" );
0 commit comments