Skip to content

Commit d02a68c

Browse files
authored
Merge pull request #5768 from NHZlX/add_upsample_layer
Add upsample layer
2 parents fb8c1cf + 663f035 commit d02a68c

File tree

12 files changed

+770
-0
lines changed

12 files changed

+770
-0
lines changed

paddle/cuda/include/hl_cnn.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,4 +370,48 @@ extern void hl_maxout_backward(real* inGrad,
370370
size_t featLen,
371371
size_t groups);
372372

373+
/**
 * @brief Upsample forward.
 *
 * Scatters each input element into the output buffer at the position
 * recorded in the max-pool mask (the transpose of max pooling).
 *
 * @param[in]  inputData   input data.
 * @param[in]  maskData    the mask data from MaxPoolWithMaskLayer.
 * @param[in]  batchSize   the batch size of the input.
 * @param[in]  imgSizeH    image height.
 * @param[in]  imgSizeW    image width.
 * @param[in]  channels    the input channels.
 * @param[in]  outputH     the output height.
 * @param[in]  outputW     the output width.
 * @param[out] outputData  output data.
 */
extern void hl_upsample_forward(real* inputData,
                                real* maskData,
                                size_t batchSize,
                                size_t imgSizeH,
                                size_t imgSizeW,
                                size_t channels,
                                size_t outputH,
                                size_t outputW,
                                real* outputData);
394+
395+
/**
 * @brief Upsample backward.
 *
 * Gathers, for each input element, the output gradient at the position
 * recorded in the max-pool mask (the inverse of the forward scatter).
 *
 * @param[in]  outputGradData  the output grad data.
 * @param[in]  maskData        the mask data from MaxPoolWithMaskLayer.
 * @param[in]  batchSize       the batch size of the input.
 * @param[in]  imgSizeH        image height.
 * @param[in]  imgSizeW        image width.
 * @param[in]  channels        the input channels.
 * @param[in]  outputH         the output height.
 * @param[in]  outputW         the output width.
 * @param[out] inputGradData   the input grad data.
 */
extern void hl_upsample_backward(real* outputGradData,
                                 real* maskData,
                                 size_t batchSize,
                                 size_t imgSizeH,
                                 size_t imgSizeW,
                                 size_t channels,
                                 size_t outputH,
                                 size_t outputW,
                                 real* inputGradData);
416+
373417
#endif // HL_CNN_H_

paddle/cuda/include/stub/hl_cnn_stub.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,4 +224,24 @@ inline void hl_maxout_backward(real* inGrad,
224224
size_t featLen,
225225
size_t group) {}
226226

227+
// No-op stub of hl_upsample_forward (the GPU implementation is in
// hl_cuda_cnn.cu); presumably used for CPU-only builds — this header is
// the stub counterpart of hl_cnn.h.
inline void hl_upsample_forward(real* inputData,
                                real* maskData,
                                size_t batchSize,
                                size_t imgSizeH,
                                size_t imgSizeW,
                                size_t channels,
                                size_t outputH,
                                size_t outputW,
                                real* outputData) {}
236+
237+
// No-op stub of hl_upsample_backward (the GPU implementation is in
// hl_cuda_cnn.cu); presumably used for CPU-only builds.
inline void hl_upsample_backward(real* outputGradData,
                                 real* maskData,
                                 size_t batchSize,
                                 size_t imgSizeH,
                                 size_t imgSizeW,
                                 size_t channels,
                                 size_t outputH,
                                 size_t outputW,
                                 real* inputGradData) {}
246+
227247
#endif // HL_CNN_STUB_H_

paddle/cuda/src/hl_cuda_cnn.cu

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,3 +1028,79 @@ void hl_maxout_backward(real* inGrad,
10281028
num_kernels, inGrad, outGrad, idData, size, featLen, groups);
10291029
CHECK_SYNC("hl_maxout_backward failed");
10301030
}
1031+
1032+
/*
 * Forward upsampling kernel: one thread per input element, 1-D launch.
 *
 * For input element i, the owning feature map is i / (in_h * in_w); the
 * mask holds that element's offset inside the matching (out_h, out_w)
 * output map, and the value is scattered to that position.  Positions not
 * named by the mask are left untouched (NOTE(review): assumes the caller
 * zero-initialized output_data — confirm against resetOutput()).
 */
__global__ void upsampleForwardCompute(real* input_data,
                                       real* mask_data,
                                       size_t nthreads,
                                       size_t in_h,
                                       size_t in_w,
                                       size_t out_h,
                                       size_t out_w,
                                       real* output_data) {
  // Use a size_t flat index: the original `int` index overflows (UB) once
  // nthreads = batch * channels * H * W exceeds INT_MAX, and the widened
  // product avoids 32-bit wrap in blockIdx.x * blockDim.x.
  size_t index = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (index < nthreads) {
    // Start of this element's output feature map in the flat output buffer.
    size_t offset = index / (in_w * in_h) * out_h * out_w;
    // Mask stores the target position as a real; truncate back to an index.
    size_t upsample_idx = static_cast<size_t>(mask_data[index]);
    output_data[offset + upsample_idx] = input_data[index];
  }
}
1047+
1048+
/*
 * Backward upsampling kernel: one thread per input element, 1-D launch.
 *
 * Inverse of the forward scatter: each input-gradient element gathers the
 * output gradient from the position its mask entry points at inside the
 * matching (out_h, out_w) gradient map.
 */
__global__ void upsampleBackwardCompute(real* out_grad,
                                        real* mask_data,
                                        size_t nthreads,
                                        size_t in_h,
                                        size_t in_w,
                                        size_t out_h,
                                        size_t out_w,
                                        real* input_grad) {
  // size_t flat index: the original `int` index overflows (UB) once
  // nthreads = batch * channels * H * W exceeds INT_MAX.
  size_t index = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (index < nthreads) {
    // Start of this element's output-gradient map in the flat buffer.
    size_t offset = index / (in_w * in_h) * out_h * out_w;
    size_t upsample_idx = static_cast<size_t>(mask_data[index]);
    input_grad[index] = out_grad[offset + upsample_idx];
  }
}
1063+
1064+
/*
 * Launches the forward upsampling kernel: one thread per input element,
 * 1024 threads per block on STREAM_DEFAULT.
 *
 * inputData/maskData are flat device buffers of
 * batchSize * channels * imgSizeH * imgSizeW elements; outputData is the
 * flat device output of batchSize * channels * outputH * outputW elements.
 */
void hl_upsample_forward(real* inputData,
                         real* maskData,
                         size_t batchSize,
                         size_t imgSizeH,
                         size_t imgSizeW,
                         size_t channels,
                         size_t outputH,
                         size_t outputW,
                         real* outputData) {
  // size_t: the original `int` product overflows for tensors with more
  // than INT_MAX elements, corrupting both the grid size and the kernel's
  // bounds check.
  size_t num_kernels = batchSize * imgSizeH * imgSizeW * channels;
  // Ceil-division; grid x-dimension comfortably fits in int here.
  int blocks = static_cast<int>((num_kernels + 1024 - 1) / 1024);
  upsampleForwardCompute<<<blocks, 1024, 0, STREAM_DEFAULT>>>(inputData,
                                                              maskData,
                                                              num_kernels,
                                                              imgSizeH,
                                                              imgSizeW,
                                                              outputH,
                                                              outputW,
                                                              outputData);
  CHECK_SYNC("hl_upsample_forward failed");
}
1085+
1086+
/*
 * Launches the backward upsampling kernel: one thread per input-gradient
 * element, 1024 threads per block on STREAM_DEFAULT.
 *
 * outputGradData/maskData layouts mirror hl_upsample_forward; the result
 * is written into the flat inputGradData buffer of
 * batchSize * channels * imgSizeH * imgSizeW elements.
 */
void hl_upsample_backward(real* outputGradData,
                          real* maskData,
                          size_t batchSize,
                          size_t imgSizeH,
                          size_t imgSizeW,
                          size_t channels,
                          size_t outputH,
                          size_t outputW,
                          real* inputGradData) {
  // size_t: the original `int` product overflows for tensors with more
  // than INT_MAX elements.
  size_t num_kernels = batchSize * imgSizeH * imgSizeW * channels;
  int blocks = static_cast<int>((num_kernels + 1024 - 1) / 1024);
  upsampleBackwardCompute<<<blocks, 1024, 0, STREAM_DEFAULT>>>(outputGradData,
                                                               maskData,
                                                               num_kernels,
                                                               imgSizeH,
                                                               imgSizeW,
                                                               outputH,
                                                               outputW,
                                                               inputGradData);
  CHECK_SYNC("hl_upsample_backward failed");
}
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "UpsampleLayer.h"
16+
#include "iostream"
17+
18+
namespace paddle {
19+
20+
REGISTER_LAYER(upsample, UpsampleLayer);
21+
22+
size_t UpsampleLayer::getOutputSize() {
23+
if (upsampleSize_ == 0) {
24+
upsampleSize_ = imgSize_ * scale_ - static_cast<int>(padOutX_);
25+
upsampleSizeY_ = imgSizeY_ * scaleY_ - static_cast<int>(padOutY_);
26+
}
27+
return upsampleSize_ * upsampleSizeY_ * channels_;
28+
}
29+
30+
bool UpsampleLayer::init(const LayerMap& layerMap,
31+
const ParameterMap& parameterMap) {
32+
Layer::init(layerMap, parameterMap);
33+
34+
CHECK_EQ(inputLayers_.size(), 2U);
35+
CHECK_EQ(config_.inputs_size(), 2);
36+
const auto& conf = config_.inputs(0).upsample_conf();
37+
const auto& img_conf = conf.image_conf();
38+
39+
imgSizeY_ =
40+
img_conf.has_img_size_y() ? img_conf.img_size_y() : img_conf.img_size();
41+
imgSize_ = img_conf.img_size();
42+
channels_ = img_conf.channels();
43+
44+
CHECK((conf.has_upsample_size()) || (conf.has_scale()))
45+
<< "scale or upsample_size is required.";
46+
47+
if (conf.has_upsample_size()) {
48+
upsampleSize_ = conf.upsample_size();
49+
upsampleSizeY_ = upsampleSize_;
50+
if (conf.has_upsample_size_y()) {
51+
upsampleSizeY_ = conf.upsample_size_y();
52+
}
53+
} else {
54+
if (!conf.has_scale_y()) {
55+
scale_ = scaleY_ = conf.scale_y();
56+
CHECK_GT(static_cast<int>(scale_), 1);
57+
} else {
58+
scale_ = conf.scale();
59+
scaleY_ = conf.scale_y();
60+
}
61+
padOutX_ = conf.pad_out_x();
62+
padOutY_ = conf.pad_out_y();
63+
CHECK(!padOutX_ || scale_ == 2)
64+
<< "Output height padding compensation requires scale_ == 2";
65+
CHECK(!padOutY_ || scaleY_ == 2)
66+
<< "Output width padding compensation requires scaleY_ == 2";
67+
upsampleSize_ = upsampleSizeY_ = 0;
68+
}
69+
return true;
70+
}
71+
72+
// Forward pass: scatters input values into the output according to the
// mask produced by the max-pool-with-mask layer (second input).
void UpsampleLayer::forward(PassType passType) {
  Layer::forward(passType);

  MatrixPtr inputMat = getInputValue(0);
  MatrixPtr maskMat = inputLayers_[1]->getOutput("mask").value;

  size_t numSamples = inputMat->getHeight();
  size_t outWidth = getOutputSize();

  // Data and mask must describe the same elements, sample for sample.
  CHECK_EQ(inputMat->getWidth(), maskMat->getWidth());
  CHECK_EQ(maskMat->getHeight(), numSamples);
  resetOutput(numSamples, outWidth);

  getOutputValue()->upsampleForward(*inputMat,
                                    *maskMat,
                                    imgSize_,
                                    imgSizeY_,
                                    channels_,
                                    upsampleSize_,
                                    upsampleSizeY_);
}
94+
95+
void UpsampleLayer::backward(const UpdateCallback& callback) {
96+
MatrixPtr mask = inputLayers_[1]->getOutput("mask").value;
97+
MatrixPtr inputGrad = getInputGrad(0);
98+
MatrixPtr outputGrad = getOutputGrad();
99+
inputGrad->upsampleBackward(*outputGrad,
100+
*mask,
101+
imgSize_,
102+
imgSizeY_,
103+
channels_,
104+
upsampleSize_,
105+
upsampleSizeY_);
106+
}
107+
108+
} // namespace paddle

paddle/gserver/layers/UpsampleLayer.h

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include <vector>
18+
#include "Layer.h"
19+
#include "paddle/math/Matrix.h"
20+
#include "paddle/utils/Logging.h"
21+
#include "paddle/utils/Stat.h"
22+
23+
namespace paddle {
24+
25+
/**
 * This layer transposes the pooling process: it scatters input values back
 * to the positions recorded during max pooling.
 *
 * It takes two inputs: the first is the data to be upsampled, and the
 * second is the mask data from the max-pool-with-mask layer.
 */

class UpsampleLayer : public Layer {
public:
  explicit UpsampleLayer(const LayerConfig& config) : Layer(config) {}
  ~UpsampleLayer() {}

  bool init(const LayerMap& layerMap,
            const ParameterMap& parameterMap) override;

  void forward(PassType passType) override;
  void backward(const UpdateCallback& callback) override;

  // Per-sample output size (upsampleSize_ * upsampleSizeY_ * channels_);
  // derives the upsample dims from scale when they were not set explicitly.
  size_t getOutputSize();

protected:
  size_t scale_, scaleY_;                // upsampling factors (x, y)
  size_t upsampleSize_, upsampleSizeY_;  // output dims; 0 = derive from scale
  size_t padOutX_, padOutY_;             // output padding compensation
  size_t imgSize_, imgSizeY_;            // input image width / height
  size_t channels_;                      // input channel count
};
52+
53+
} // namespace paddle

paddle/gserver/tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ gserver_test(test_BatchNorm)
2727
gserver_test(test_KmaxSeqScore)
2828
gserver_test(test_Expand)
2929
gserver_test(test_MaxPoolingWithMaskOutput)
30+
gserver_test(test_Upsample)
3031

3132
set(PYTHON_PATH
3233
${PADDLE_SOURCE_DIR}/paddle/.set_python_path.sh -d

0 commit comments

Comments
 (0)