
Commit 3e6f768

Merge pull request #4891 from NHZlX/poolmaxpool_with_mask
max pool Layer with mask
2 parents 4adc8a7 + e19b931 commit 3e6f768

File tree: 14 files changed, +357 −29 lines


paddle/cuda/include/hl_cnn.h

Lines changed: 4 additions & 3 deletions
@@ -18,7 +18,7 @@ limitations under the License. */
 #include "hl_base.h"
 
 /**
- * @brief Maximum pool forward.
+ * @brief Maximum pool forward with Mask output.
  *
  * @param[in] frameCnt batch size of input image.
  * @param[in] inputData input data.
@@ -35,7 +35,7 @@ limitations under the License. */
  * @param[in] paddingW padding width.
  * @param[out] tgtData output data.
  * @param[in] tgtStride stride between output data samples.
- *
+ * @param[out] maskData the location indices of select max data.
  */
 extern void hl_maxpool_forward(const int frameCnt,
                                const real* inputData,
@@ -51,7 +51,8 @@ extern void hl_maxpool_forward(const int frameCnt,
                                const int paddingH,
                                const int paddingW,
                                real* tgtData,
-                               const int tgtStride);
+                               const int tgtStride,
+                               real* maskData = NULL);
 
 /**
  * @brief Maximum pool backward.
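
The mask is exposed as a defaulted trailing argument, so existing callers of hl_maxpool_forward keep compiling unchanged and simply skip the mask write; only callers that pass a buffer get the argmax indices back. A minimal sketch of that calling pattern, using a hypothetical toy reduction rather than the real hl_maxpool_forward signature (which takes the full set of size/stride/padding parameters):

#include <cstdio>

// Toy stand-in illustrating the defaulted trailing mask pointer.
void reduce_max(const float* in, int n, float* out, float* mask = NULL) {
  float best = in[0];
  int bestIdx = 0;
  for (int i = 1; i < n; ++i) {
    if (best < in[i]) {
      best = in[i];
      bestIdx = i;
    }
  }
  *out = best;
  if (mask != NULL) *mask = static_cast<float>(bestIdx);  // index stored as a real, like maskData
}

int main() {
  const float x[4] = {0.1f, 0.7f, 0.3f, 0.5f};
  float y = 0.f, m = -1.f;
  reduce_max(x, 4, &y);      // existing-style call: no mask requested
  reduce_max(x, 4, &y, &m);  // new-style call: mask requested
  std::printf("max=%.1f at index %.0f\n", y, m);  // prints: max=0.7 at index 1
  return 0;
}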

paddle/cuda/include/stub/hl_cnn_stub.h

Lines changed: 2 additions & 1 deletion
@@ -31,7 +31,8 @@ inline void hl_maxpool_forward(const int frameCnt,
                                const int paddingH,
                                const int paddingW,
                                real* tgtData,
-                               const int tgtStride) {}
+                               const int tgtStride,
+                               real* MaskData) {}
 
 inline void hl_maxpool_backward(const int frameCnt,
                                 const real* inputData,

paddle/cuda/src/hl_cuda_cnn.cu

Lines changed: 14 additions & 5 deletions
@@ -31,7 +31,8 @@ __global__ void KeMaxPoolForward(const int nthreads,
                                  const int offsetH,
                                  const int offsetW,
                                  real* tgtData,
-                                 const int tgtStride) {
+                                 const int tgtStride,
+                                 real* maskData) {
   int index = blockIdx.x * blockDim.x + threadIdx.x;
   if (index < nthreads) {
     int pw = index % pooledW;
@@ -45,16 +46,22 @@ __global__ void KeMaxPoolForward(const int nthreads,
     hstart = max(hstart, 0);
     wstart = max(wstart, 0);
     real maxval = -FLT_MAX;
+    int max_index = -1;
     inputData += (frameNum * channels + c) * height * width;
     for (int h = hstart; h < hend; ++h) {
       for (int w = wstart; w < wend; ++w) {
-        if (maxval < inputData[h * width + w])
-          maxval = inputData[h * width + w];
+        if (maxval < inputData[h * width + w]) {
+          max_index = h * width + w;
+          maxval = inputData[max_index];
+        }
       }
     }
     int tgtIndex =
         index % (pooledW * pooledH * channels) + frameNum * tgtStride;
     tgtData[tgtIndex] = maxval;
+    if (maskData != NULL) {
+      maskData[tgtIndex] = max_index;
+    }
   }
 }
 
@@ -72,7 +79,8 @@ void hl_maxpool_forward(const int frameCnt,
                         const int paddingH,
                         const int paddingW,
                         real* tgtData,
-                        const int tgtStride) {
+                        const int tgtStride,
+                        real* maskData) {
   int num_kernels = pooledH * pooledW * channels * frameCnt;
   int blocks = (num_kernels + 1024 - 1) / 1024;
   dim3 threads(1024, 1);
@@ -92,7 +100,8 @@ void hl_maxpool_forward(const int frameCnt,
                         paddingH,
                         paddingW,
                         tgtData,
-                        tgtStride);
+                        tgtStride,
+                        maskData);
   CHECK_SYNC("hl_maxpool_forward failed");
 }
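
The kernel change keeps an extra max_index alongside maxval while scanning the pooling window and, only when maskData is non-NULL, writes it to the same tgtIndex as the pooled value. The stored index is h * width + w, i.e. the flat offset inside one channel's input plane. A host-side sketch of the same bookkeeping for a single frame and channel, ignoring padding for brevity (illustrative, not part of this commit):

#include <algorithm>
#include <cfloat>
#include <vector>

// CPU reference mirroring KeMaxPoolForward's argmax bookkeeping:
// mask[k] holds the flat in-plane offset (h * width + w) of the window max.
void maxPoolWithMaskRef(const std::vector<float>& input, int height, int width,
                        int sizeY, int sizeX, int strideH, int strideW,
                        int pooledH, int pooledW,
                        std::vector<float>* out, std::vector<float>* mask) {
  out->assign(pooledH * pooledW, 0.f);
  mask->assign(pooledH * pooledW, -1.f);
  for (int ph = 0; ph < pooledH; ++ph) {
    for (int pw = 0; pw < pooledW; ++pw) {
      int hstart = ph * strideH, hend = std::min(hstart + sizeY, height);
      int wstart = pw * strideW, wend = std::min(wstart + sizeX, width);
      float maxval = -FLT_MAX;
      int max_index = -1;
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
          if (maxval < input[h * width + w]) {
            max_index = h * width + w;
            maxval = input[max_index];
          }
        }
      }
      (*out)[ph * pooledW + pw] = maxval;
      (*mask)[ph * pooledW + pw] = static_cast<float>(max_index);
    }
  }
}

// Recovering the 2-D location from a stored mask entry:
//   int idx = static_cast<int>(mask[k]);
//   int h = idx / width, w = idx % width;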

paddle/gserver/layers/MaxPoolWithMaskLayer.cpp (new file)

Lines changed: 109 additions & 0 deletions
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "MaxPoolWithMaskLayer.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h"

namespace paddle {

bool MaxPoolWithMaskLayer::init(const LayerMap& layerMap,
                                const ParameterMap& parameterMap) {
  PoolLayer::init(layerMap, parameterMap);
  setOutput("mask", &mask_);
  return true;
}

size_t MaxPoolWithMaskLayer::getSize() {
  CHECK_EQ(inputLayers_.size(), 1UL);
  size_t layerSize = 0;

  outputY_ = outputSize(imgSizeY_,
                        sizeY_,
                        confPaddingY_,
                        strideY_,
                        /* caffeMode */ false);
  outputX_ = outputSize(imgSize_,
                        sizeX_,
                        confPadding_,
                        stride_,
                        /* caffeMode */ false);

  layerSize = outputX_ * outputY_ * channels_;
  getOutput().setFrameHeight(outputY_);
  getOutput().setFrameWidth(outputX_);

  return layerSize;
}

void MaxPoolWithMaskLayer::forward(PassType passType) {
  size_t size = getSize();
  MatrixPtr inputV = inputLayers_[0]->getOutputValue();
  int batchSize = inputV->getHeight();
  resetOutput(batchSize, size);

  MatrixPtr outV = getOutputValue();
  CHECK_EQ(size, outV->getWidth());

  resetSpecifyOutput(mask_,
                     batchSize,
                     size,
                     /* isValueClean */ false,
                     /* isGradClean */ true);

  MatrixPtr maskV = mask_.value;
  outV->maxPoolForward(*inputV,
                       imgSizeY_,
                       imgSize_,
                       channels_,
                       sizeX_,
                       sizeY_,
                       strideY_,
                       stride_,
                       outputY_,
                       outputX_,
                       confPaddingY_,
                       confPadding_,
                       maskV);
}

void MaxPoolWithMaskLayer::backward(const UpdateCallback& callback) {
  (void)callback;
  if (NULL == getInputGrad(0)) {
    return;
  }

  MatrixPtr outGrad = getOutputGrad();
  MatrixPtr inputV = inputLayers_[0]->getOutputValue();
  MatrixPtr outV = getOutputValue();
  MatrixPtr inputGrad = inputLayers_[0]->getOutputGrad();

  inputGrad->maxPoolBackward(*inputV,
                             imgSizeY_,
                             imgSize_,
                             *outGrad,
                             *outV,
                             sizeX_,
                             sizeY_,
                             strideY_,
                             stride_,
                             outputY_,
                             outputX_,
                             1,
                             1,
                             confPaddingY_,
                             confPadding_);
}

}  // namespace paddle
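
getSize() computes the pooled spatial dimensions with /* caffeMode */ false, which, as I understand Paddle's outputSize helper (an assumption; the helper itself is not part of this diff), selects ceil rather than floor rounding of the window count; forward() then sizes both the pooled value and the mask to batchSize x (outputX_ * outputY_ * channels_). A small sketch of the assumed arithmetic:

// Assumed equivalent of outputSize(imageSize, windowSize, padding, stride,
// /* caffeMode */ false): ceil rounding expressed in integer arithmetic.
inline int pooledDimCeil(int imageSize, int windowSize, int padding, int stride) {
  return (imageSize - windowSize + 2 * padding + stride - 1) / stride + 1;
}

// Example: a 6-wide input, 3-wide window, stride 2, no padding.
//   ceil  (caffeMode = false): (6 - 3 + 2 - 1) / 2 + 1 = 3 pooled columns
//   floor (caffeMode = true) : (6 - 3) / 2 + 1         = 2 pooled columns
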
paddle/gserver/layers/MaxPoolWithMaskLayer.h (new file)

Lines changed: 40 additions & 0 deletions
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <vector>
#include "PoolLayer.h"
#include "paddle/math/Matrix.h"

namespace paddle {
/**
 * @brief Basic parent layer of different kinds of pooling
 */
class MaxPoolWithMaskLayer : public PoolLayer {
protected:
  Argument mask_;

public:
  explicit MaxPoolWithMaskLayer(const LayerConfig& config)
      : PoolLayer(config) {}

  size_t getSize();

  void forward(PassType passType) override;
  void backward(const UpdateCallback& callback = nullptr) override;
  bool init(const LayerMap& layerMap,
            const ParameterMap& parameterMap) override;
};
}  // namespace paddle
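
Because init() registers mask_ with setOutput("mask", &mask_), the layer exposes a second, named output next to its default pooled value. A hedged sketch of reading it from downstream code, assuming Layer also provides a name-keyed getOutput(const std::string&) accessor (an assumption about the Layer API; only setOutput appears in this diff):

// Assumptions: `layer` is a forward-ed "max-pool-with-mask" layer, and
// Layer::getOutput(const std::string&) returns the Argument registered
// via setOutput() -- neither is shown in this diff.
void readMask(const paddle::LayerPtr& layer) {
  const paddle::Argument& mask = layer->getOutput("mask");  // assumed accessor
  paddle::MatrixPtr maskV = mask.value;  // argmax indices stored as real values
  // maskV has shape batchSize x (outputX_ * outputY_ * channels_),
  // matching the pooled output produced by forward().
  (void)maskV;
}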

paddle/gserver/layers/PoolLayer.cpp

Lines changed: 3 additions & 1 deletion
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "PoolLayer.h"
+#include "MaxPoolWithMaskLayer.h"
 #include "PoolProjectionLayer.h"
 #include "paddle/utils/Logging.h"
 #ifdef PADDLE_WITH_CUDA
@@ -44,7 +45,6 @@ bool PoolLayer::init(const LayerMap& layerMap,
   strideY_ = conf.has_stride_y() ? conf.stride_y() : conf.stride();
   confPaddingY_ = conf.has_padding_y() ? conf.padding_y() : conf.padding();
   outputY_ = conf.has_output_y() ? conf.output_y() : conf.output_x();
-
   return true;
 }
 
@@ -57,6 +57,8 @@ Layer* PoolLayer::create(const LayerConfig& config) {
   } else if (CudnnPoolLayer::typeCheck(pool)) {
     return new CudnnPoolLayer(config);
 #endif
+  } else if (pool == "max-pool-with-mask") {
+    return new MaxPoolWithMaskLayer(config);
   } else {
     LOG(FATAL) << "Unknown pool type: " << pool;
     return nullptr;

paddle/gserver/tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ gserver_test(test_ConvUnify)
 gserver_test(test_BatchNorm)
 gserver_test(test_KmaxSeqScore)
 gserver_test(test_Expand)
+gserver_test(test_MaxPoolingWithMaskOutput)
 
 ########## test_Mkldnn layers and activations ##########
 if(WITH_MKLDNN)
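
The new test target test_MaxPoolingWithMaskOutput is registered here; its source file is not part of this excerpt, but given the kernel's index convention, a minimal hand check of the mask on a tiny input might look like the following (illustrative only, not the actual test code):

#include <cassert>
#include <cfloat>

// 4x4 single-channel input, 2x2 max pooling with stride 2 -> 2x2 output.
// Each mask entry stores h * width + w of its window's maximum, as the kernel does.
int main() {
  const int H = 4, W = 4, K = 2, S = 2, PH = 2, PW = 2;
  const float in[H * W] = {1, 2, 5, 0,
                           3, 4, 1, 2,
                           0, 9, 1, 1,
                           8, 7, 6, 2};
  const float expectedMask[PH * PW] = {5, 2, 9, 14};  // worked out by hand
  for (int ph = 0; ph < PH; ++ph) {
    for (int pw = 0; pw < PW; ++pw) {
      float maxval = -FLT_MAX;
      int max_index = -1;
      for (int h = ph * S; h < ph * S + K; ++h) {
        for (int w = pw * S; w < pw * S + K; ++w) {
          if (maxval < in[h * W + w]) {
            max_index = h * W + w;
            maxval = in[max_index];
          }
        }
      }
      assert(static_cast<float>(max_index) == expectedMask[ph * PW + pw]);
    }
  }
  return 0;
}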

paddle/gserver/tests/test_LayerGrad.cpp

Lines changed: 2 additions & 0 deletions
@@ -1234,6 +1234,7 @@ void testPoolLayer2(const string& poolType, bool trans, bool useGpu) {
 TEST(Layer, PoolLayer) {
   testPoolLayer("avg-projection", /* trans= */ false, /* useGpu= */ false);
   testPoolLayer("max-projection", /* trans= */ false, /* useGpu= */ false);
+  testPoolLayer("max-pool-with-mask", /* trans= */ false, /* useGpu= */ false);
 
 #ifdef PADDLE_WITH_CUDA
   testPoolLayer("avg-projection", /* trans= */ false, /* useGpu= */ true);
@@ -1242,6 +1243,7 @@ TEST(Layer, PoolLayer) {
   testPoolLayer("cudnn-avg-pool", /* trans= */ false, /* useGpu= */ true);
   testPoolLayer2("cudnn-max-pool", /* trans= */ false, /* useGpu= */ true);
   testPoolLayer2("cudnn-avg-pool", /* trans= */ false, /* useGpu= */ true);
+  testPoolLayer("max-pool-with-mask", /* trans= */ false, /* useGpu= */ true);
 #endif
 }
