@@ -32,24 +32,22 @@ class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
32
32
" The format of input tensor is NCHW. Where N is batch size, C is the "
33
33
" number of channels, H and W is the height and width of feature." );
34
34
AddOutput (" Out" ,
35
- " (Tensor) The output tensor of unpool operator."
36
- " The format of output tensor is also NCHW."
37
- " Where N is batch size, C is "
38
- " the number of channels, H and W is the height and "
39
- " width of feature." );
35
+ " (Tensor) The output tensor of unpool operator."
36
+ " The format of output tensor is also NCHW."
37
+ " Where N is batch size, C is "
38
+ " the number of channels, H and W is the height and "
39
+ " width of feature." );
40
40
AddAttr<std::vector<int >>(
41
41
" ksize" ,
42
42
" (vector), the unpooling window size(height, width) "
43
43
" of unpooling operator." );
44
- AddAttr<std::vector<int >>(
45
- " strides" ,
46
- " (vector, default:{1, 1}), "
47
- " strides (height, width) of unpooling operator." )
44
+ AddAttr<std::vector<int >>(" strides" ,
45
+ " (vector, default:{1, 1}), "
46
+ " strides (height, width) of unpooling operator." )
48
47
.SetDefault ({1 , 1 });
49
- AddAttr<std::vector<int >>(
50
- " paddings" ,
51
- " (vector defalut:{0,0}), "
52
- " paddings (height, width) of unpooling operator." )
48
+ AddAttr<std::vector<int >>(" paddings" ,
49
+ " (vector defalut:{0,0}), "
50
+ " paddings (height, width) of unpooling operator." )
53
51
.SetDefault ({0 , 0 });
54
52
AddAttr<std::string>(
55
53
" unpooling_type" ,
@@ -75,71 +73,71 @@ int OutputSize(int input_size, int ksize, int padding, int stride) {
75
73
}
76
74
77
75
class UnpoolOp : public framework ::OperatorWithKernel {
78
- protected:
79
- framework::OpKernelType GetKernelType (
80
- const framework::ExecutionContext& ctx) const override {
76
+ protected:
77
+ framework::OpKernelType GetKernelType (
78
+ const framework::ExecutionContext& ctx) const override {
81
79
return framework::OpKernelType (
82
- framework::ToDataType (ctx.Input <framework::Tensor>(" X" )->type ()),
80
+ framework::ToDataType (ctx.Input <framework::Tensor>(" X" )->type ()),
83
81
ctx.device_context ());
84
82
}
85
83
86
- public:
87
- using framework::OperatorWithKernel::OperatorWithKernel;
88
- void InferShape (framework::InferShapeContext* ctx) const override {
89
- PADDLE_ENFORCE (ctx->HasInput (" X" ), " Input(X) of UnpoolOp"
84
+ public:
85
+ using framework::OperatorWithKernel::OperatorWithKernel;
86
+ void InferShape (framework::InferShapeContext* ctx) const override {
87
+ PADDLE_ENFORCE (ctx->HasInput (" X" ), " Input(X) of UnpoolOp"
90
88
" should not be null." );
91
- PADDLE_ENFORCE (ctx->HasInput (" Indices" ), " Input(Indices) of UnpoolOp"
89
+ PADDLE_ENFORCE (ctx->HasInput (" Indices" ), " Input(Indices) of UnpoolOp"
92
90
" should not be null." );
93
- PADDLE_ENFORCE (ctx->HasOutput (" Out" ),
91
+ PADDLE_ENFORCE (ctx->HasOutput (" Out" ),
94
92
" Output(Out) of UnpoolOp should not be null." );
95
- auto in_x_dims = ctx->GetInputDim (" X" );
96
- auto in_y_dims = ctx->GetInputDim (" Indices" );
97
- std::string unpooling_type =
93
+ auto in_x_dims = ctx->GetInputDim (" X" );
94
+ auto in_y_dims = ctx->GetInputDim (" Indices" );
95
+ std::string unpooling_type =
98
96
ctx->Attrs ().Get <std::string>(" unpooling_type" );
99
- std::vector<int > ksize = ctx->Attrs ().Get <std::vector<int >>(" ksize" );
100
- std::vector<int > strides = ctx->Attrs ().Get <std::vector<int >>(" strides" );
101
- std::vector<int > paddings =
97
+ std::vector<int > ksize = ctx->Attrs ().Get <std::vector<int >>(" ksize" );
98
+ std::vector<int > strides = ctx->Attrs ().Get <std::vector<int >>(" strides" );
99
+ std::vector<int > paddings =
102
100
ctx->Attrs ().Get <std::vector<int >>(" paddings" );
103
- PADDLE_ENFORCE (in_x_dims.size () == 4 ,
101
+ PADDLE_ENFORCE (in_x_dims.size () == 4 ,
104
102
" Unpooling intput must be of 4-dimensional." );
105
- PADDLE_ENFORCE_EQ (in_x_dims, in_y_dims);
106
- std::vector<int64_t > output_shape ({in_x_dims[0 ], in_x_dims[1 ]});
107
- for (size_t i = 0 ; i < ksize.size (); ++i) {
108
- output_shape.push_back (
109
- OutputSize (in_x_dims[i + 2 ], ksize[i], paddings[i], strides[i]));
110
- }
111
- ctx->SetOutputDim (" Out" , framework::make_ddim (output_shape));
112
- }
103
+ PADDLE_ENFORCE_EQ (in_x_dims, in_y_dims);
104
+ std::vector<int64_t > output_shape ({in_x_dims[0 ], in_x_dims[1 ]});
105
+ for (size_t i = 0 ; i < ksize.size (); ++i) {
106
+ output_shape.push_back (
107
+ OutputSize (in_x_dims[i + 2 ], ksize[i], paddings[i], strides[i]));
108
+ }
109
+ ctx->SetOutputDim (" Out" , framework::make_ddim (output_shape));
110
+ }
113
111
};
114
112
115
113
class UnpoolOpGrad : public framework ::OperatorWithKernel {
116
- protected:
117
- framework::OpKernelType GetKernelType (
118
- const framework::ExecutionContext& ctx) const override {
119
- return framework::OpKernelType (
120
- framework::ToDataType (ctx.Input <framework::Tensor>(" X" )->type ()),
121
- ctx.device_context ());
122
- }
114
+ protected:
115
+ framework::OpKernelType GetKernelType (
116
+ const framework::ExecutionContext& ctx) const override {
117
+ return framework::OpKernelType (
118
+ framework::ToDataType (ctx.Input <framework::Tensor>(" X" )->type ()),
119
+ ctx.device_context ());
120
+ }
123
121
124
- public:
125
- using framework::OperatorWithKernel::OperatorWithKernel;
126
- void InferShape (framework::InferShapeContext* ctx) const override {
127
- PADDLE_ENFORCE (ctx->HasInput (" X" ), " Input(X) must not be null." );
128
- PADDLE_ENFORCE (ctx->HasOutput (framework::GradVarName (" X" )),
122
+ public:
123
+ using framework::OperatorWithKernel::OperatorWithKernel;
124
+ void InferShape (framework::InferShapeContext* ctx) const override {
125
+ PADDLE_ENFORCE (ctx->HasInput (" X" ), " Input(X) must not be null." );
126
+ PADDLE_ENFORCE (ctx->HasOutput (framework::GradVarName (" X" )),
129
127
" Input(X@GRAD) should not be null." );
130
- ctx->SetOutputDim (framework::GradVarName (" X" ), ctx->GetInputDim (" X" ));
131
- }
128
+ ctx->SetOutputDim (framework::GradVarName (" X" ), ctx->GetInputDim (" X" ));
129
+ }
132
130
};
133
- } // namespace operators
134
- } // namespace paddle
131
+ } // namespace operators
132
+ } // namespace paddle
135
133
136
134
namespace ops = paddle::operators;
137
135
REGISTER_OP (unpool, ops::UnpoolOp, ops::Unpool2dOpMaker, unpool_grad,
138
136
ops::UnpoolOpGrad);
139
137
REGISTER_OP_CPU_KERNEL (
140
- unpool, ops::UnpoolKernel<paddle::platform::CPUPlace, float >,
141
- ops::UnpoolKernel<paddle::platform::CPUPlace, double >);
138
+ unpool, ops::UnpoolKernel<paddle::platform::CPUPlace, float >,
139
+ ops::UnpoolKernel<paddle::platform::CPUPlace, double >);
142
140
REGISTER_OP_CPU_KERNEL (
143
- unpool_grad, ops::UnpoolGradKernel<paddle::platform::CPUPlace, float >,
144
- ops::UnpoolGradKernel<paddle::platform::CPUPlace, double >);
141
+ unpool_grad, ops::UnpoolGradKernel<paddle::platform::CPUPlace, float >,
142
+ ops::UnpoolGradKernel<paddle::platform::CPUPlace, double >);
145
143
0 commit comments