@@ -24,12 +24,50 @@ class BatchSizeLikeOp : public framework::OperatorWithKernel {
24
24
public:
25
25
using framework::OperatorWithKernel::OperatorWithKernel;
26
26
27
- void InferShape (framework::InferShapeContext *ctx) const override ;
27
+ void InferShape (framework::InferShapeContext *ctx) const override {
28
+ PADDLE_ENFORCE (ctx->HasInput (" Input" ),
29
+ " Input(Input) of %s should not be null." , Type ());
30
+ PADDLE_ENFORCE (ctx->HasOutput (" Out" ),
31
+ " Output(Out) of %s should not be null." , Type ());
32
+
33
+ auto &shape = ctx->Attrs ().Get <std::vector<int >>(" shape" );
34
+ PADDLE_ENFORCE_GT (shape.size (), 0 );
35
+ std::vector<int64_t > shape_int64 (shape.size (), 0 );
36
+ std::transform (shape.begin (), shape.end (), shape_int64.begin (),
37
+ [](int a) { return static_cast <int64_t >(a); });
38
+ auto output_dim = framework::make_ddim (shape_int64);
39
+
40
+ int input_dim_idx = ctx->Attrs ().Get <int >(" input_dim_idx" );
41
+ PADDLE_ENFORCE_GE (input_dim_idx, 0 );
42
+ PADDLE_ENFORCE_GT (ctx->GetInputDim (" Input" ).size (), input_dim_idx);
43
+
44
+ int output_dim_idx = ctx->Attrs ().Get <int >(" output_dim_idx" );
45
+ PADDLE_ENFORCE_GE (output_dim_idx, 0 );
46
+ PADDLE_ENFORCE_GT (static_cast <int >(shape.size ()), output_dim_idx);
47
+
48
+ output_dim[output_dim_idx] = ctx->GetInputDim (" Input" )[input_dim_idx];
49
+ ctx->SetOutputDim (" Out" , output_dim);
50
+ }
28
51
};
29
52
30
53
class BatchSizeLikeOpMaker : public framework ::OpProtoAndCheckerMaker {
31
54
public:
32
- BatchSizeLikeOpMaker (OpProto *proto, OpAttrChecker *op_checker);
55
+ BatchSizeLikeOpMaker (OpProto *proto, OpAttrChecker *op_checker)
56
+ : framework::OpProtoAndCheckerMaker(proto, op_checker) {
57
+ AddInput (" Input" ,
58
+ " (Tensor) Tensor "
59
+ " whose input_dim_idx'th dimension specifies the batch_size" );
60
+ AddOutput (" Out" ,
61
+ " (Tensor) Tensor of specified shape will be filled "
62
+ " with the specified value" );
63
+ AddAttr<std::vector<int >>(" shape" , " (vector<int>) The shape of the output" );
64
+ AddAttr<int >(" input_dim_idx" ,
65
+ " (int, default 0) The index of input's batch size dimension" )
66
+ .SetDefault (0 );
67
+ AddAttr<int >(" output_dim_idx" ,
68
+ " (int, default 0) The index of output's batch size dimension" )
69
+ .SetDefault (0 );
70
+ }
33
71
};
34
72
35
73
} // namespace operators
0 commit comments