Skip to content

Commit de80c56

Browse files
authored
Merge pull request #6100 from guoshengCS/enhance-include-pool
Enhance AvgPooling to support both include_mode and exclude_mode
2 parents c4599d3 + e135894 commit de80c56

File tree

14 files changed

+114
-39
lines changed

14 files changed

+114
-39
lines changed

paddle/cuda/include/hl_cnn.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ extern void hl_maxpool_backward(const int frameCnt,
116116
* @param[in] paddingW padding width.
117117
* @param[out] tgtData output data.
118118
* @param[in] tgtStride stride between output data samples.
119+
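* @param[in] excludeMode if true, padding is excluded from the averaging divisor (the window clipped to the image is used); if false, the full sizeY * sizeX window is used.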
119120
*
120121
*/
121122
extern void hl_avgpool_forward(const int frameCnt,
@@ -132,7 +133,8 @@ extern void hl_avgpool_forward(const int frameCnt,
132133
const int paddingH,
133134
const int paddingW,
134135
real* tgtData,
135-
const int tgtStride);
136+
const int tgtStride,
137+
bool excludeMode);
136138

137139
/**
138140
* @brief Average pool backward.
@@ -154,6 +156,7 @@ extern void hl_avgpool_forward(const int frameCnt,
154156
* @param[in] scaleB scale.
155157
* @param[out] backGrad output grad.
156158
* @param[in] outStride stride between output data samples.
159+
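* @param[in] excludeMode if true, padding is excluded from the averaging divisor (the window clipped to the image is used); if false, the full sizeY * sizeX window is used.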
157160
*
158161
*/
159162
extern void hl_avgpool_backward(const int frameCnt,
@@ -172,7 +175,8 @@ extern void hl_avgpool_backward(const int frameCnt,
172175
real scaleA,
173176
real scaleB,
174177
real* backGrad,
175-
const int outStride);
178+
const int outStride,
179+
bool excludeMode);
176180

177181
extern void hl_maxpool3D_forward(const int frameCnt,
178182
const real* inputData,

paddle/cuda/include/stub/hl_cnn_stub.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ inline void hl_avgpool_forward(const int frameCnt,
6868
const int paddingH,
6969
const int paddingW,
7070
real* tgtData,
71-
const int tgtStride) {}
71+
const int tgtStride,
72+
const bool excludeMode) {}
7273

7374
inline void hl_avgpool_backward(const int frameCnt,
7475
const real* outGrad,
@@ -86,7 +87,8 @@ inline void hl_avgpool_backward(const int frameCnt,
8687
real scaleA,
8788
real scaleB,
8889
real* backGrad,
89-
const int outStride) {}
90+
const int outStride,
91+
const bool excludeMode) {}
9092

9193
inline void hl_maxpool3D_forward(const int frameCnt,
9294
const real* inputData,

paddle/cuda/src/hl_cuda_cnn.cu

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,8 @@ __global__ void KeAvgPoolForward(const int nthreads,
210210
const int padH,
211211
const int padW,
212212
real* tgtData,
213-
const int tgtStride) {
213+
const int tgtStride,
214+
const bool excludeMode) {
214215
int index = blockIdx.x * blockDim.x + threadIdx.x;
215216
if (index < nthreads) {
216217
int pw = index % pooledW;
@@ -224,7 +225,8 @@ __global__ void KeAvgPoolForward(const int nthreads,
224225
int wend = min(wstart + sizeX, width);
225226
hstart = max(hstart, 0);
226227
wstart = max(wstart, 0);
227-
int pool_size = (hend - hstart) * (wend - wstart);
228+
int poolSize =
229+
excludeMode ? (hend - hstart) * (wend - wstart) : sizeY * sizeX;
228230

229231
real aveval = 0;
230232
inputData += (frameNum * channels + c) * height * width;
@@ -235,7 +237,7 @@ __global__ void KeAvgPoolForward(const int nthreads,
235237
}
236238
int tgtIndex =
237239
index % (pooledW * pooledH * channels) + frameNum * tgtStride;
238-
tgtData[tgtIndex] = aveval / pool_size;
240+
tgtData[tgtIndex] = aveval / poolSize;
239241
}
240242
}
241243

@@ -253,7 +255,8 @@ void hl_avgpool_forward(const int frameCnt,
253255
const int paddingH,
254256
const int paddingW,
255257
real* tgtData,
256-
const int tgtStride) {
258+
const int tgtStride,
259+
const bool excludeMode) {
257260
int num_kernels = pooledH * pooledW * channels * frameCnt;
258261
int blocks = (num_kernels + 1024 - 1) / 1024;
259262
KeAvgPoolForward<<<blocks, 1024, 0, STREAM_DEFAULT>>>(num_kernels,
@@ -270,7 +273,8 @@ void hl_avgpool_forward(const int frameCnt,
270273
paddingH,
271274
paddingW,
272275
tgtData,
273-
tgtStride);
276+
tgtStride,
277+
excludeMode);
274278
CHECK_SYNC("hl_avgpool_forward failed");
275279
}
276280

@@ -290,7 +294,8 @@ __global__ void KeAvgPoolBackward(const int nthreads,
290294
real scaleA,
291295
real scaleB,
292296
real* tgtGrad,
293-
const int outStride) {
297+
const int outStride,
298+
const bool excludeMode) {
294299
int index = blockIdx.x * blockDim.x + threadIdx.x;
295300
if (index < nthreads) {
296301
int offsetW = index % width + padW;
@@ -314,8 +319,9 @@ __global__ void KeAvgPoolBackward(const int nthreads,
314319
int wstart = pw * strideW - padW;
315320
int wend = min(wstart + sizeX, width);
316321
wstart = max(wstart, 0);
317-
int poolsize = (hend - hstart) * (wend - wstart);
318-
gradient += outGrad[ph * pooledW + pw] / poolsize;
322+
int poolSize =
323+
excludeMode ? (hend - hstart) * (wend - wstart) : sizeY * sizeX;
324+
gradient += outGrad[ph * pooledW + pw] / poolSize;
319325
}
320326
}
321327
tgtGrad[index] = scaleB * tgtGrad[index] + scaleA * gradient;
@@ -338,7 +344,8 @@ void hl_avgpool_backward(const int frameCnt,
338344
real scaleA,
339345
real scaleB,
340346
real* backGrad,
341-
const int outStride) {
347+
const int outStride,
348+
const bool excludeMode) {
342349
int num_kernels = height * width * channels * frameCnt;
343350
int blocks = (num_kernels + 1024 - 1) / 1024;
344351

@@ -358,7 +365,8 @@ void hl_avgpool_backward(const int frameCnt,
358365
scaleA,
359366
scaleB,
360367
backGrad,
361-
outStride);
368+
outStride,
369+
excludeMode);
362370
CHECK_SYNC("hl_avgpool_backward failed");
363371
}
364372

paddle/gserver/layers/PoolLayer.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ bool PoolLayer::init(const LayerMap& layerMap,
4545
strideY_ = conf.has_stride_y() ? conf.stride_y() : conf.stride();
4646
confPaddingY_ = conf.has_padding_y() ? conf.padding_y() : conf.padding();
4747
outputY_ = conf.has_output_y() ? conf.output_y() : conf.output_x();
48+
49+
excludeMode_ = conf.has_exclude_mode() ? conf.exclude_mode() : true;
4850
return true;
4951
}
5052

paddle/gserver/layers/PoolLayer.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ class PoolLayer : public Layer {
3838

3939
std::string poolType_;
4040

41+
bool excludeMode_;
42+
4143
public:
4244
explicit PoolLayer(const LayerConfig& config) : Layer(config) {}
4345

paddle/gserver/layers/PoolProjection.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ PoolProjection::PoolProjection(const ProjectionConfig& config,
3636
strideY_ = conf.has_stride_y() ? conf.stride_y() : conf.stride();
3737
confPaddingY_ = conf.has_padding_y() ? conf.padding_y() : conf.padding();
3838
outputY_ = conf.has_output_y() ? conf.output_y() : conf.output_x();
39+
40+
excludeMode_ = conf.has_exclude_mode() ? conf.exclude_mode() : true;
3941
}
4042

4143
size_t PoolProjection::getSize() {
@@ -141,7 +143,8 @@ void AvgPoolProjection::forward() {
141143
outputY_,
142144
outputX_,
143145
confPaddingY_,
144-
confPadding_);
146+
confPadding_,
147+
excludeMode_);
145148
}
146149

147150
void AvgPoolProjection::backward(const UpdateCallback& callback) {
@@ -166,6 +169,7 @@ void AvgPoolProjection::backward(const UpdateCallback& callback) {
166169
1,
167170
1,
168171
confPaddingY_,
169-
confPadding_);
172+
confPadding_,
173+
excludeMode_);
170174
}
171175
} // namespace paddle

paddle/gserver/layers/PoolProjection.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class PoolProjection : public Projection {
2828
int confPaddingY_, confPadding_;
2929
size_t channels_;
3030
std::string poolType_;
31+
bool excludeMode_;
3132

3233
public:
3334
PoolProjection(const ProjectionConfig& config,

paddle/gserver/tests/test_LayerGrad.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1211,14 +1211,18 @@ void setPoolConfig(TestConfig* config,
12111211
pool->set_output_y(oh);
12121212
}
12131213

1214-
void testPoolLayer(const string& poolType, bool trans, bool useGpu) {
1214+
void testPoolLayer(const string& poolType,
1215+
bool trans,
1216+
bool useGpu,
1217+
bool excludeMode = true) {
12151218
TestConfig config;
12161219
config.inputDefs.push_back({INPUT_DATA, "layer_0", 3136, 0});
12171220
LayerInputConfig* input = config.layerConfig.add_inputs();
12181221
PoolConfig* pool = input->mutable_pool_conf();
12191222

12201223
pool->set_img_size(14);
12211224
pool->set_img_size_y(14);
1225+
pool->set_exclude_mode(excludeMode);
12221226
setPoolConfig(&config, pool, poolType);
12231227
config.layerConfig.set_size(pool->output_x() * pool->output_y() *
12241228
pool->channels());
@@ -1250,16 +1254,26 @@ void testPoolLayer2(const string& poolType, bool trans, bool useGpu) {
12501254

12511255
TEST(Layer, PoolLayer) {
12521256
testPoolLayer("avg-projection", /* trans= */ false, /* useGpu= */ false);
1257+
testPoolLayer("avg-projection",
1258+
/* trans= */ false,
1259+
/* useGpu= */ false,
1260+
/* excludeMode= */ false);
12531261
testPoolLayer("max-projection", /* trans= */ false, /* useGpu= */ false);
12541262
testPoolLayer("max-pool-with-mask", /* trans= */ false, /* useGpu= */ false);
12551263

12561264
#ifdef PADDLE_WITH_CUDA
12571265
testPoolLayer("avg-projection", /* trans= */ false, /* useGpu= */ true);
1266+
testPoolLayer("avg-projection",
1267+
/* trans= */ false,
1268+
/* useGpu= */ true,
1269+
/* excludeMode= */ false);
12581270
testPoolLayer("max-projection", /* trans= */ false, /* useGpu= */ true);
12591271
testPoolLayer("cudnn-max-pool", /* trans= */ false, /* useGpu= */ true);
12601272
testPoolLayer("cudnn-avg-pool", /* trans= */ false, /* useGpu= */ true);
12611273
testPoolLayer2("cudnn-max-pool", /* trans= */ false, /* useGpu= */ true);
12621274
testPoolLayer2("cudnn-avg-pool", /* trans= */ false, /* useGpu= */ true);
1275+
testPoolLayer2(
1276+
"cudnn-avg-incl-pad-pool", /* trans= */ false, /* useGpu= */ true);
12631277
testPoolLayer("max-pool-with-mask", /* trans= */ false, /* useGpu= */ true);
12641278
#endif
12651279
}

paddle/math/Matrix.cpp

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1130,7 +1130,8 @@ void GpuMatrix::avgPoolForward(Matrix& inputMat,
11301130
size_t outputH,
11311131
size_t outputW,
11321132
size_t paddingH,
1133-
size_t paddingW) {
1133+
size_t paddingW,
1134+
bool excludeMode) {
11341135
CHECK(inputMat.useGpu_ == true) << "Matrix type are not equal";
11351136

11361137
real* inputData = inputMat.getData();
@@ -1153,7 +1154,8 @@ void GpuMatrix::avgPoolForward(Matrix& inputMat,
11531154
paddingH,
11541155
paddingW,
11551156
data_,
1156-
getStride());
1157+
getStride(),
1158+
excludeMode);
11571159
}
11581160

11591161
void GpuMatrix::avgPoolBackward(Matrix& outGrad,
@@ -1168,7 +1170,8 @@ void GpuMatrix::avgPoolBackward(Matrix& outGrad,
11681170
real scaleTargets,
11691171
real scaleOutput,
11701172
size_t paddingH,
1171-
size_t paddingW) {
1173+
size_t paddingW,
1174+
bool excludeMode) {
11721175
CHECK(outGrad.useGpu_ == true) << "Matrix type are not equal";
11731176

11741177
real* outDiff = outGrad.getData();
@@ -1194,7 +1197,8 @@ void GpuMatrix::avgPoolBackward(Matrix& outGrad,
11941197
scaleTargets,
11951198
scaleOutput,
11961199
data_,
1197-
outGrad.getStride());
1200+
outGrad.getStride(),
1201+
excludeMode);
11981202
}
11991203

12001204
void GpuMatrix::maxPool3DForward(Matrix& inputMat,
@@ -2136,7 +2140,8 @@ void CpuMatrix::avgPoolForward(Matrix& input,
21362140
size_t outputH,
21372141
size_t outputW,
21382142
size_t paddingH,
2139-
size_t paddingW) {
2143+
size_t paddingW,
2144+
bool excludeMode) {
21402145
// The main loop
21412146
size_t num = input.getHeight();
21422147
size_t inLength = imgSizeH * imgSizeW;
@@ -2165,7 +2170,8 @@ void CpuMatrix::avgPoolForward(Matrix& input,
21652170
tgtData[ph * outputW + pw] += inData[h * imgSizeW + w];
21662171
}
21672172
}
2168-
int poolSize = (hend - hstart) * (wend - wstart);
2173+
int poolSize =
2174+
excludeMode ? (hend - hstart) * (wend - wstart) : sizeY * sizeX;
21692175
CHECK(poolSize);
21702176
tgtData[ph * outputW + pw] /= poolSize;
21712177
}
@@ -2189,7 +2195,8 @@ void CpuMatrix::avgPoolBackward(Matrix& input,
21892195
real scaleTargets,
21902196
real scaleOutput,
21912197
size_t paddingH,
2192-
size_t paddingW) {
2198+
size_t paddingW,
2199+
bool excludeMode) {
21932200
size_t num = input.getHeight();
21942201
size_t channels = input.getWidth() / outputH / outputW;
21952202
size_t inLength = imgSizeH * imgSizeW;
@@ -2211,7 +2218,8 @@ void CpuMatrix::avgPoolBackward(Matrix& input,
22112218
int wstart = pw * strideW - paddingW;
22122219
int wend = std::min(wstart + sizeX, imgSizeW);
22132220
wstart = std::max(wstart, 0);
2214-
int poolSize = (hend - hstart) * (wend - wstart);
2221+
int poolSize =
2222+
excludeMode ? (hend - hstart) * (wend - wstart) : sizeY * sizeX;
22152223
CHECK(poolSize);
22162224

22172225
for (int h = hstart; h < hend; ++h) {

paddle/math/Matrix.h

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -911,7 +911,8 @@ class Matrix : public BaseMatrix {
911911
size_t outputH,
912912
size_t outputW,
913913
size_t paddingH,
914-
size_t paddingW) {
914+
size_t paddingW,
915+
bool excludeMode = true) {
915916
LOG(FATAL) << "Not implemeted";
916917
}
917918

@@ -927,9 +928,11 @@ class Matrix : public BaseMatrix {
927928
real scaleTargets,
928929
real scaleOutput,
929930
size_t paddingH,
930-
size_t paddingW) {
931+
size_t paddingW,
932+
bool excludeMode = true) {
931933
LOG(FATAL) << "Not implemeted";
932934
}
935+
933936
/**
934937
* Pooling 3D forward operation, pick out the largest element
935938
* in the sizeX of value
@@ -1458,7 +1461,8 @@ class GpuMatrix : public Matrix {
14581461
size_t outputH,
14591462
size_t outputW,
14601463
size_t paddingH,
1461-
size_t paddingW);
1464+
size_t paddingW,
1465+
bool excludeMode = true);
14621466

14631467
void avgPoolBackward(Matrix& input,
14641468
size_t imgSizeH,
@@ -1472,7 +1476,8 @@ class GpuMatrix : public Matrix {
14721476
real scaleTargets,
14731477
real scaleOutput,
14741478
size_t paddingH,
1475-
size_t paddingW);
1479+
size_t paddingW,
1480+
bool excludeMode = true);
14761481

14771482
void maxPool3DForward(Matrix& inputMat,
14781483
Matrix& maxPoolIdx,
@@ -1730,7 +1735,8 @@ class CpuMatrix : public Matrix {
17301735
size_t outputH,
17311736
size_t outputW,
17321737
size_t paddingH,
1733-
size_t paddingW);
1738+
size_t paddingW,
1739+
bool excludeMode = true);
17341740

17351741
void avgPoolBackward(Matrix& input,
17361742
size_t imgSizeH,
@@ -1744,7 +1750,8 @@ class CpuMatrix : public Matrix {
17441750
real scaleTargets,
17451751
real scaleOutput,
17461752
size_t paddingH,
1747-
size_t paddingW);
1753+
size_t paddingW,
1754+
bool excludeMode = true);
17481755

17491756
void maxPool3DForward(Matrix& inputMat,
17501757
Matrix& maxPoolIdx,

0 commit comments

Comments
 (0)