@@ -17,90 +17,66 @@ limitations under the License. */
17
17
namespace paddle {
18
18
namespace operators {
19
19
20
- class ReshapeOp : public framework ::OperatorWithKernel {
21
- public:
22
- ReshapeOp (const std::string &type, const framework::VariableNameMap &inputs,
23
- const framework::VariableNameMap &outputs,
24
- const framework::AttributeMap &attrs)
25
- : OperatorWithKernel(type, inputs, outputs, attrs) {}
26
-
27
- void InferShape (framework::InferShapeContext *ctx) const override {
28
- // input check
29
- PADDLE_ENFORCE (ctx->HasInput (" X" ),
30
- " Input(X) of ReshapeOp should not be null." );
31
- PADDLE_ENFORCE (ctx->HasOutput (" Out" ),
32
- " Output(Out) of ReshapeOp should not be null." );
33
-
34
- auto shape = ctx->Attrs ().Get <std::vector<int >>(" shape" );
35
- PADDLE_ENFORCE (shape.size () > 0 , " Attr(shape) shouldn't be empty." );
36
- auto x_dims = ctx->GetInputDim (" X" );
37
-
38
- std::vector<size_t > neg_dims_idx;
39
- // set some dimension to -1 if it is unknown
40
- const int unknown_size = -1 ;
41
- for (size_t i = 0 ; i < shape.size (); ++i) {
42
- PADDLE_ENFORCE (shape[i] > 0 || shape[i] == unknown_size,
43
- " Each dimension of Attr(shape) must be positive or %d." ,
44
- unknown_size);
45
- if (shape[i] == unknown_size) {
46
- neg_dims_idx.push_back (i);
47
- PADDLE_ENFORCE (neg_dims_idx.size () <= 1 ,
48
- " Only one dimension of Attr(shape) can be unknown." );
49
- }
50
- }
51
-
52
- int64_t capacity =
53
- std::accumulate (shape.begin (), shape.end (), 1 , std::multiplies<int >());
54
- int64_t in_size = framework::product (x_dims);
55
- if (neg_dims_idx.size () == 1 ) {
56
- // dim infer
57
- shape[neg_dims_idx[0 ]] = in_size / (-capacity);
58
- // recalculate capacity
59
- capacity = shape[neg_dims_idx[0 ]] * (-capacity);
60
- }
61
- // capacity check
62
- PADDLE_ENFORCE (capacity == in_size,
63
- " The size of Input(X) mismatches with Attr(shape)." );
64
- // resize output
65
- std::vector<int64_t > shape_int64 (shape.size (), 0 );
66
- std::transform (shape.begin (), shape.end (), shape_int64.begin (),
67
- [](int a) { return static_cast <int64_t >(a); });
68
- auto out_dims = framework::make_ddim (shape_int64);
69
- ctx->SetOutputDim (" Out" , out_dims);
70
- if (shape[0 ] == x_dims[0 ]) {
71
- // Only pass LoD when the first dimension is equal between
72
- // output and input.
73
- ctx->ShareLoD (" X" , /* ->*/ " Out" );
74
- }
75
- }
76
- };
77
-
78
20
class ReshapeOpMaker : public framework ::OpProtoAndCheckerMaker {
79
21
public:
80
22
ReshapeOpMaker (OpProto *proto, OpAttrChecker *op_checker)
81
23
: OpProtoAndCheckerMaker(proto, op_checker) {
82
- AddInput (" X" , " The input tensor of reshape operator." );
83
- AddOutput (" Out" , " The output tensor of reshape operator." );
84
- AddAttr<std::vector<int >>(" shape" ,
85
- " (vector<int>) "
86
- " Target shape of reshape operator." );
24
+ AddInput (" X" , " (Tensor). The input tensor of reshape operator." );
25
+ AddInput (" Shape" ,
26
+ " (Tensor<int32>, optional). If provided, reshape according to "
27
+ " this given shape. That is to say it has a higher priority than "
28
+ " the shape attribute, while the shape attribute still should be "
29
+ " set correctly to gurantee shape inference in compile time." )
30
+ .AsDispensable ();
31
+ AddOutput (" Out" , " (Tensor). The output tensor of reshape operator." );
32
+ AddAttr<std::vector<int >>(
33
+ " shape" , " (std::vector<int>) Target shape of reshape operator." );
87
34
AddAttr<bool >(" inplace" ,
88
- " Change the source tensor's shape without copy memory." )
89
- .SetDefault (true );
35
+ " (default: false) Change the source tensor's shape without "
36
+ " memory copy. When Attr(inplace) is set true, the output "
37
+ " tensor shares memory with Input(X), otherwise, a new output "
38
+ " tensor is created, and its data are copied from Input(x)." )
39
+ .SetDefault (false );
90
40
AddComment (R"DOC(
91
41
Reshape Operator.
92
42
93
- Reshape Input(X) into the shape specified by Attr(shape).
43
+ Reshape Input(X) into the shape specified by Attr(shape) or Input(Shape). The
44
+ data in Input(X) are unchanged.
45
+
46
+ Examples:
94
47
95
- An example:
96
- Given a 2-D tensor X with 2 rows and 2 columns : [[1, 2], [3, 4]]
48
+ 1. Given a 3-D tensor Input(X) with a shape [2, 4, 6], and the target shape
49
+ specified by Attr(shape) is [6, 8], the reshape operator will transform Input(X)
50
+ into a 2-D tensor with shape [6, 8] and leaving Input(X)'s data unchanged.
97
51
98
- and target shape = [1, 4], the reshape operator will transform
99
- the tensor X into a 2-D tensor: [[1, 2, 3, 4]]
52
+ 2. Given a 3-D tensor Input(X) with a shape [2, 4, 6], and the target shape
53
+ specified by Attr(shape) is [2, 3, -1, 2], the reshape operator will transform
54
+ Input(X) into a 4-D tensor with shape [2, 3, 4, 2] and leaving Input(X)'s data
55
+ unchanged. In this case, one and only dimension of Attr(shape) can be set to -1,
56
+ the value of this dimension is inferred from the total element number of
57
+ Input(X) and remaining dimensions.
58
+
59
+ 3. Given a 3-D tensor Input(X) with a shape [2, 4, 6], and the target shape
60
+ specified by Attr(shape) is [-1, 0, 3, 2], the reshape operator will transform
61
+ Input(X) into a 4-D tensor with shape [2, 4, 3, 2] and leaving Input(X)'s data
62
+ unchanged. In this case, besides -1, 0 means the actual dimension value is going
63
+ to be copied from the corresponding dimension of Input(X).
64
+
65
+ Note:
66
+
67
+ 1. One and only one dimension in Attr(shape) can be set -1. In this case,
68
+ the actual dimension value will be infered from the total element number of
69
+ Input(X) and remaining dimensions.
70
+
71
+ 2. More than one dimensions in Attr(shape) can be set to 0, which means the real
72
+ dimension value will be copied from Input(X) at runtime. Note that the index of
73
+ 0 can not exceed Rank(X). For example, Input(X) is a 3-D tensor with shape
74
+ [2, 3, 4], Attr(shape) = [2, 3, 2, 0] is an invalid input.
75
+
76
+ 3. Input(Shape) has a higher priority than Attr(shape) if it is provided, while
77
+ Attr(shape) still should be set correctly to gurantee shape inference in
78
+ compile-time.
100
79
101
- One dimension in the target shape can be set -1, representing that its
102
- size is unknown. In this case, the real dimension will be infered from
103
- the original shape of Input(X) and other dimensions in the target shape.
104
80
)DOC" );
105
81
}
106
82
};
@@ -119,6 +95,14 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
119
95
" Input(Out@GRAD) shouldn't be null." );
120
96
ctx->SetOutputDim (framework::GradVarName (" X" ), ctx->GetInputDim (" X" ));
121
97
}
98
+
99
+ protected:
100
+ framework::OpKernelType GetExpectedKernelType (
101
+ const framework::ExecutionContext &ctx) const override {
102
+ return framework::OpKernelType (
103
+ framework::ToDataType (ctx.Input <framework::LoDTensor>(" X" )->type ()),
104
+ ctx.device_context ());
105
+ }
122
106
};
123
107
124
108
} // namespace operators
0 commit comments