Skip to content

Commit 8efd087

Browse files
author
chengduo
authored
Merge pull request #5187 from chengduoZH/fix_pool_op
fix pool op
2 parents 0049ce0 + 6bdf5c1 commit 8efd087

File tree

8 files changed

+109
-75
lines changed

8 files changed

+109
-75
lines changed

paddle/operators/pool_cudnn_op.cu

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class PoolCudnnOpKernel : public framework::OpKernel<T> {
4343
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
4444
if (ctx.Attr<bool>("globalPooling")) {
4545
for (size_t i = 0; i < ksize.size(); ++i) {
46+
paddings[i] = 0;
4647
ksize[i] = static_cast<int>(input->dims()[i + 2]);
4748
}
4849
}
@@ -97,8 +98,10 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> {
9798
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
9899

99100
if (ctx.Attr<bool>("globalPooling")) {
100-
for (size_t i = 0; i < ksize.size(); ++i)
101+
for (size_t i = 0; i < ksize.size(); ++i) {
102+
paddings[i] = 0;
101103
ksize[i] = static_cast<int>(input->dims()[i + 2]);
104+
}
102105
}
103106

104107
const T *input_data = input->data<T>();

paddle/operators/pool_op.cc

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,10 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const {
3939

4040
if (ctx->Attrs().Get<bool>("globalPooling")) {
4141
ksize.resize(static_cast<size_t>(in_x_dims.size()) - 2);
42-
for (size_t i = 0; i < ksize.size(); ++i)
42+
for (size_t i = 0; i < ksize.size(); ++i) {
43+
paddings[i] = 0;
4344
ksize[i] = static_cast<int>(in_x_dims[i + 2]);
45+
}
4446
}
4547

4648
PADDLE_ENFORCE(in_x_dims.size() - ksize.size() == 2U,
@@ -84,15 +86,16 @@ Pool2dOpMaker::Pool2dOpMaker(framework::OpProto *proto,
8486
"(string), pooling type, can be \"max\" for max-pooling "
8587
"and \"avg\" for average-pooling.")
8688
.InEnum({"max", "avg"});
87-
AddAttr<std::vector<int>>(
88-
"ksize",
89-
"(vector ), the pooling window size(height, width) of pooling operator."
90-
"If globalPooling = true, ksize is ignored and need not be "
91-
"specified."); // TODO(Chengduo): Add checker. (Currently,
89+
AddAttr<std::vector<int>>("ksize",
90+
"(vector ), the pooling window size(height, width) "
91+
"of pooling operator."
92+
"If globalPooling = true, ksize and paddings will "
93+
"be ignored."); // TODO(Chengduo): Add checker.
94+
// (Currently,
9295
// TypedAttrChecker don't support vector type.)
9396
AddAttr<bool>("globalPooling",
9497
"(bool default: false), whether to use the global pooling."
95-
"If globalPooling = true, ksize is ignored.")
98+
"If globalPooling = true, ksize and paddings will be ignored.")
9699
.SetDefault(false);
97100
AddAttr<std::vector<int>>(
98101
"strides",
@@ -101,7 +104,8 @@ Pool2dOpMaker::Pool2dOpMaker(framework::OpProto *proto,
101104
// TypedAttrChecker don't support vector type.)
102105
AddAttr<std::vector<int>>(
103106
"paddings",
104-
"(vector defalut:{0,0}), paddings(height, width) of pooling operator.")
107+
"(vector defalut:{0,0}), paddings(height, width) of pooling operator."
108+
"If globalPooling = true, paddings and ksize will be ignored.")
105109
.SetDefault({0, 0}); // TODO(Chengduo): Add checker. (Currently,
106110
// TypedAttrChecker don't support vector type.)
107111

@@ -145,25 +149,28 @@ Pool3dOpMaker::Pool3dOpMaker(framework::OpProto *proto,
145149
"(string), pooling type, can be \"max\" for max-pooling "
146150
"and \"avg\" for average-pooling.")
147151
.InEnum({"max", "avg"});
148-
AddAttr<std::vector<int>>(
149-
"ksize",
150-
"(vector ), the pooling window size(depth, height, width) of pooling "
151-
"operator."
152-
"If globalPooling = true, ksize is ignored and need not be "
153-
"specified."); // TODO(Chengduo): Add checker. (Currently,
154-
// TypedAttrChecker don't support vector type.)
152+
AddAttr<std::vector<int>>("ksize",
153+
"(vector ), the pooling window size(depth, height, "
154+
"width) of pooling "
155+
"operator."
156+
"If globalPooling = true, ksize and paddings wille "
157+
"be ignored."); // TODO(Chengduo): Add checker.
158+
// (Currently,
159+
// TypedAttrChecker don't support vector type.)
155160
AddAttr<bool>("globalPooling",
156161
"(bool default: false), whether to use the global pooling."
157-
"If globalPooling = true, ksize is ignored.")
162+
"If globalPooling = true, ksize and paddings wille be ignored.")
158163
.SetDefault(false);
159164
AddAttr<std::vector<int>>("strides",
160165
"(vector, default:{1,1,1}), strides(depth, height, "
161166
"width) of pooling operator.")
162167
.SetDefault({1, 1, 1}); // TODO(Chengduo): Add checker. (Currently,
163168
// TypedAttrChecker don't support vector type.)
164-
AddAttr<std::vector<int>>("paddings",
165-
"(vector defalut:{0,0,0}), paddings(depth, height, "
166-
"width) of pooling operator.")
169+
AddAttr<std::vector<int>>(
170+
"paddings",
171+
"(vector defalut:{0,0,0}), paddings(depth, height, "
172+
"width) of pooling operator."
173+
"If globalPooling = true, ksize and paddings wille be ignored.")
167174
.SetDefault({0, 0, 0}); // TODO(Chengduo): Add checker. (Currently,
168175
// TypedAttrChecker don't support vector type.)
169176

paddle/operators/pool_op.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ class PoolKernel : public framework::OpKernel<T> {
6363
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
6464
if (context.Attr<bool>("globalPooling")) {
6565
for (size_t i = 0; i < ksize.size(); ++i) {
66+
paddings[i] = 0;
6667
ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
6768
}
6869
}
@@ -103,6 +104,7 @@ class PoolKernel : public framework::OpKernel<T> {
103104
paddings, pool_process);
104105
}
105106
} break;
107+
default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
106108
}
107109
}
108110
};
@@ -123,8 +125,10 @@ class PoolGradKernel : public framework::OpKernel<T> {
123125
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
124126

125127
if (context.Attr<bool>("globalPooling")) {
126-
for (size_t i = 0; i < ksize.size(); ++i)
128+
for (size_t i = 0; i < ksize.size(); ++i) {
129+
paddings[i] = 0;
127130
ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
131+
}
128132
}
129133

130134
if (in_x_grad) {
@@ -164,6 +168,7 @@ class PoolGradKernel : public framework::OpKernel<T> {
164168
*out_grad, ksize, strides, paddings, pool_process);
165169
}
166170
} break;
171+
default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
167172
}
168173
}
169174
}

paddle/operators/pool_with_index_op.cc

Lines changed: 37 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,10 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel {
4646

4747
if (ctx->Attrs().Get<bool>("globalPooling")) {
4848
ksize.resize(static_cast<size_t>(in_x_dims.size()) - 2);
49-
for (size_t i = 0; i < ksize.size(); ++i)
49+
for (size_t i = 0; i < ksize.size(); ++i) {
50+
paddings[i] = 0;
5051
ksize[i] = static_cast<int>(in_x_dims[i + 2]);
52+
}
5153
}
5254

5355
PADDLE_ENFORCE(in_x_dims.size() - ksize.size() == 2U,
@@ -87,31 +89,33 @@ class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
8789
: OpProtoAndCheckerMaker(proto, op_checker) {
8890
AddInput(
8991
"X",
90-
"(Tensor) The input tensor of pooling operator. "
92+
"(Tensor), the input tensor of pooling operator. "
9193
"The format of input tensor is NCHW. Where N is batch size, C is the "
9294
"number of channels, H and W is the height and width of image.");
9395
AddOutput("Out",
94-
"(Tensor) The output tensor of pooling operator."
96+
"(Tensor), the output tensor of pooling operator."
9597
"The format of output tensor is also NCHW."
9698
"Where N is batch size, C is "
9799
"the number of channels, H and W is the height and "
98100
"width of image.");
99101
AddOutput("Mask",
100-
"(Tensor) The Mask tensor of pooling operator."
102+
"(Tensor), the Mask tensor of pooling operator."
101103
"The format of output tensor is also NCHW."
102104
"Where N is batch size, C is the number of channels, H and W "
103105
"is the height and width of image."
104106
"The value in it is the index in current feature map");
105107

106-
AddAttr<std::vector<int>>(
107-
"ksize",
108-
"(vector ), the pooling window size(height, width) of pooling operator."
109-
"If globalPooling = true, ksize is ignored and need not be "
110-
"specified."); // TODO(Chengduo): Add checker. (Currently,
108+
AddAttr<std::vector<int>>("ksize",
109+
"(vector ), the pooling window size(height, "
110+
"width) of pooling operator."
111+
"If globalPooling = true, ksize and paddings "
112+
"will be ignored."); // TODO(Chengduo): Add
113+
// checker. (Currently,
111114
// TypedAttrChecker don't support vector type.)
112-
AddAttr<bool>("globalPooling",
113-
"(bool default: false), whether to use the global pooling."
114-
"If globalPooling = true, ksize is ignored.")
115+
AddAttr<bool>(
116+
"globalPooling",
117+
"(bool default: false), whether to use the global pooling."
118+
"If globalPooling = true, ksize and paddings will be ignored.")
115119
.SetDefault(false);
116120
AddAttr<std::vector<int>>(
117121
"strides",
@@ -120,7 +124,8 @@ class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
120124
// TypedAttrChecker don't support vector type.)
121125
AddAttr<std::vector<int>>(
122126
"paddings",
123-
"(vector defalut:{0,0}), paddings(height, width) of pooling operator.")
127+
"(vector defalut:{0, 0}), paddings(height, width) of pooling operator."
128+
"If globalPooling = true, paddings and will be ignored.")
124129
.SetDefault({0, 0}); // TODO(Chengduo): Add checker. (Currently,
125130
// TypedAttrChecker don't support vector type.)
126131

@@ -153,42 +158,46 @@ class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
153158
: OpProtoAndCheckerMaker(proto, op_checker) {
154159
AddInput(
155160
"X",
156-
"(Tensor) The input tensor of pooling operator. "
161+
"(Tensor), the input tensor of pooling operator. "
157162
"The format of input tensor is NCDHW. Where N is batch size, C is "
158163
"the number of channels, D, H and W is the depth, height and width of "
159164
"image.");
160165
AddOutput("Out",
161-
"(Tensor) The output tensor of pooling operator."
166+
"(Tensor), the output tensor of pooling operator."
162167
"The format of output tensor is also NCDHW."
163168
"Where N is batch size, C is "
164169
"the number of channels, D, H and W is the depth, height and "
165170
"width of image.");
166171
AddOutput("Mask",
167-
"(Tensor) The Mask tensor of pooling operator."
172+
"(Tensor), the Mask tensor of pooling operator."
168173
"The format of output tensor is also NCDHW."
169174
"Where N is batch size, C is the number of channels, D, H and W "
170175
"is the depth, height and width of image."
171176
"The value in it is the index in current feature map");
172177

173-
AddAttr<std::vector<int>>(
174-
"ksize",
175-
"(vector ), the pooling window size(depth, height, width) of pooling "
176-
"operator."
177-
"If globalPooling = true, ksize is ignored and need not be "
178-
"specified."); // TODO(Chengduo): Add checker. (Currently,
178+
AddAttr<std::vector<int>>("ksize",
179+
"(vector), the pooling window size(depth, "
180+
"height, width) of pooling "
181+
"operator."
182+
"If globalPooling = true, ksize and paddings "
183+
"will be ignored."); // TODO(Chengduo): Add
184+
// checker. (Currently,
179185
// TypedAttrChecker don't support vector type.)
180-
AddAttr<bool>("globalPooling",
181-
"(bool default: false), whether to use the global pooling."
182-
"If globalPooling = true, ksize is ignored.")
186+
AddAttr<bool>(
187+
"globalPooling",
188+
"(bool default: false), whether to use the global pooling."
189+
"If globalPooling = true, ksize and paddings will be ignored.")
183190
.SetDefault(false);
184191
AddAttr<std::vector<int>>("strides",
185192
"(vector, default:{1,1,1}), strides(depth, "
186193
"height, width) of pooling operator.")
187194
.SetDefault({1, 1, 1}); // TODO(Chengduo): Add checker. (Currently,
188195
// TypedAttrChecker don't support vector type.)
189-
AddAttr<std::vector<int>>("paddings",
190-
"(vector defalut:{0,0,0}), paddings(depth, "
191-
"height, width) of pooling operator.")
196+
AddAttr<std::vector<int>>(
197+
"paddings",
198+
"(vector defalut:{0,0,0}), paddings(depth, "
199+
"height, width) of pooling operator."
200+
"If globalPooling = true, paddings and ksize will be ignored.")
192201
.SetDefault({0, 0, 0}); // TODO(Chengduo): Add checker. (Currently,
193202
// TypedAttrChecker don't support vector type.)
194203

paddle/operators/pool_with_index_op.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class MaxPoolWithIndexKernel : public framework::OpKernel<T> {
3737
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
3838
if (context.Attr<bool>("globalPooling")) {
3939
for (size_t i = 0; i < ksize.size(); ++i) {
40+
paddings[i] = 0;
4041
ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
4142
}
4243
}
@@ -54,6 +55,7 @@ class MaxPoolWithIndexKernel : public framework::OpKernel<T> {
5455
pool3d_forward(context.device_context(), *in_x, *out, *mask, ksize,
5556
strides, paddings);
5657
} break;
58+
default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
5759
}
5860
}
5961
};
@@ -72,6 +74,7 @@ class MaxPoolWithIndexGradKernel : public framework::OpKernel<T> {
7274
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
7375
if (context.Attr<bool>("globalPooling")) {
7476
for (size_t i = 0; i < ksize.size(); ++i) {
77+
paddings[i] = 0;
7578
ksize[i] = static_cast<int>(in_x_grad->dims()[i + 2]);
7679
}
7780
}
@@ -95,6 +98,7 @@ class MaxPoolWithIndexGradKernel : public framework::OpKernel<T> {
9598
pool3d_backward(context.device_context(), *in_x_grad, *out_grad,
9699
*mask, ksize, strides, paddings);
97100
} break;
101+
default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
98102
}
99103
}
100104
}

python/paddle/v2/framework/tests/test_pool2d_op.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,12 @@ def setUp(self):
4949
self.init_test_case()
5050
self.init_op_type()
5151
self.init_pool_type()
52+
if self.global_pool:
53+
self.paddings = [0 for _ in range(len(self.paddings))]
5254
input = np.random.random(self.shape).astype("float32")
5355
output = self.pool2D_forward_naive(input, self.ksize, self.strides,
54-
self.paddings, self.global_pool)
56+
self.paddings,
57+
self.global_pool).astype("float32")
5558
self.inputs = {'X': input}
5659

5760
self.attrs = {

python/paddle/v2/framework/tests/test_pool3d_op.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,13 @@ def avg_pool3D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
5454

5555
class TestPool3d_Op(OpTest):
5656
def setUp(self):
57-
self.initTestCase()
57+
self.init_test_case()
58+
if self.global_pool:
59+
self.paddings = [0 for _ in range(len(self.paddings))]
5860
input = np.random.random(self.shape).astype("float32")
5961
output = self.pool3D_forward_naive(input, self.ksize, self.strides,
60-
self.paddings, self.global_pool)
62+
self.paddings,
63+
self.global_pool).astype("float32")
6164
self.inputs = {'X': input}
6265

6366
self.attrs = {
@@ -77,7 +80,7 @@ def test_check_grad(self):
7780
if self.pool_type != "max":
7881
self.check_grad(set(['X']), 'Out', max_relative_error=0.07)
7982

80-
def initTestCase(self):
83+
def init_test_case(self):
8184
self.global_pool = True
8285
self.op_type = "pool3d"
8386
self.pool_type = "avg"
@@ -89,7 +92,7 @@ def initTestCase(self):
8992

9093

9194
class TestCase1(TestPool3d_Op):
92-
def initTestCase(self):
95+
def init_test_case(self):
9396
self.global_pool = False
9497
self.op_type = "pool3d"
9598
self.pool_type = "avg"
@@ -101,7 +104,7 @@ def initTestCase(self):
101104

102105

103106
class TestCase2(TestPool3d_Op):
104-
def initTestCase(self):
107+
def init_test_case(self):
105108
self.global_pool = False
106109
self.op_type = "pool3d"
107110
self.pool_type = "avg"
@@ -113,7 +116,7 @@ def initTestCase(self):
113116

114117

115118
class TestCase3(TestPool3d_Op):
116-
def initTestCase(self):
119+
def init_test_case(self):
117120
self.global_pool = True
118121
self.op_type = "pool3d"
119122
self.pool_type = "max"
@@ -125,7 +128,7 @@ def initTestCase(self):
125128

126129

127130
class TestCase4(TestPool3d_Op):
128-
def initTestCase(self):
131+
def init_test_case(self):
129132
self.global_pool = False
130133
self.op_type = "pool3d"
131134
self.pool_type = "max"
@@ -137,7 +140,7 @@ def initTestCase(self):
137140

138141

139142
class TestCase5(TestPool3d_Op):
140-
def initTestCase(self):
143+
def init_test_case(self):
141144
self.global_pool = False
142145
self.op_type = "pool3d"
143146
self.pool_type = "max"

0 commit comments

Comments
 (0)