Skip to content

Commit 8295eb9

Browse files
authored
Merge pull request #287 from gangliao/bilinear
Add bilinear interpolation layer
2 parents cfc965d + f27ff4d commit 8295eb9

File tree

16 files changed

+925
-5
lines changed

16 files changed

+925
-5
lines changed

doc/ui/api/trainer_config_helpers/layers.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,12 @@ interpolation_layer
275275
:members: interpolation_layer
276276
:noindex:
277277

278+
bilinear_interp_layer
279+
----------------------
280+
.. automodule:: paddle.trainer_config_helpers.layers
281+
:members: bilinear_interp_layer
282+
:noindex:
283+
278284
power_layer
279285
-----------
280286
.. automodule:: paddle.trainer_config_helpers.layers

paddle/cuda/include/hl_cnn.h

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,70 @@ extern void hl_CMRNorm_backward(
240240
size_t channels, size_t height, size_t width, size_t sizeX,
241241
real alpha, real beta);
242242

243+
/**
 * @brief Bilinear interpolation forward.
 *
 * Resizes every image in the batch from (inImgH x inImgW) to
 * (outImgH x outImgW). Each batch row packs numChannels image planes
 * contiguously, so inputW == numChannels * inImgH * inImgW and
 * outputW == numChannels * outImgH * outImgW.
 *
 * @param[in] inData input value.
 * @param[in] inImgH input image height.
 * @param[in] inImgW input image width.
 * @param[in] inputH input batchSize.
 * @param[in] inputW input image data dim (numChannels * inImgH * inImgW).
 * @param[out] outData output value.
 * @param[in] outImgH output image height.
 * @param[in] outImgW output image width.
 * @param[in] outputH output batchSize.
 * @param[in] outputW output image data dim (numChannels * outImgH * outImgW).
 * @param[in] numChannels number of channels.
 * @param[in] ratioH height scale factor. NOTE: callers compute this as
 *            (inImgH - 1) / (outImgH - 1), or 0 when outImgH == 1 —
 *            not inImgH / outImgH.
 * @param[in] ratioW width scale factor, computed analogously from
 *            inImgW and outImgW.
 *
 */
extern void hl_bilinear_forward(const real* inData,
                                const size_t inImgH,
                                const size_t inImgW,
                                const size_t inputH,
                                const size_t inputW,
                                real* outData,
                                const size_t outImgH,
                                const size_t outImgW,
                                const size_t outputH,
                                const size_t outputW,
                                const size_t numChannels,
                                const real ratioH,
                                const real ratioW);
275+
/**
 * @brief Bilinear interpolation backward.
 *
 * Scatters each output gradient back onto the (up to) four source
 * pixels that produced it; gradients are ACCUMULATED into inGrad, so
 * the caller is responsible for inGrad's initial contents.
 *
 * @param[out] inGrad input gradient (accumulated into, not overwritten).
 * @param[in] inImgH input image height.
 * @param[in] inImgW input image width.
 * @param[in] inputH input batchSize.
 * @param[in] inputW input image data dim (numChannels * inImgH * inImgW).
 * @param[in] outGrad output gradient.
 * @param[in] outImgH output image height.
 * @param[in] outImgW output image width.
 * @param[in] outputH output batchSize.
 * @param[in] outputW output image data dim (numChannels * outImgH * outImgW).
 * @param[in] numChannels number of channels.
 * @param[in] ratioH height scale factor. NOTE: callers compute this as
 *            (inImgH - 1) / (outImgH - 1), or 0 when outImgH == 1 —
 *            not inImgH / outImgH.
 * @param[in] ratioW width scale factor, computed analogously.
 *
 */
extern void hl_bilinear_backward(real* inGrad,
                                 const size_t inImgH,
                                 const size_t inImgW,
                                 const size_t inputH,
                                 const size_t inputW,
                                 const real* outGrad,
                                 const size_t outImgH,
                                 const size_t outImgW,
                                 const size_t outputH,
                                 const size_t outputW,
                                 const size_t numChannels,
                                 const real ratioH,
                                 const real ratioW);
243307
/**
244308
* @brief MaxOut forward.
245309
*

paddle/cuda/include/stub/hl_cnn_stub.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,34 @@ inline void hl_CMRNorm_backward(
8989
size_t channels, size_t height, size_t width, size_t sizeX,
9090
real alpha, real beta) {}
9191

92+
// No-op stub for CPU-only builds (no CUDA): bilinear forward does
// nothing here. See hl_cnn.h for the real contract of this function.
inline void hl_bilinear_forward(const real* inData,
                                const size_t inImgH,
                                const size_t inImgW,
                                const size_t inputH,
                                const size_t inputW,
                                real* outData,
                                const size_t outImgH,
                                const size_t outImgW,
                                const size_t outputH,
                                const size_t outputW,
                                const size_t numChannels,
                                const real ratioH,
                                const real ratioW) {}
106+
// No-op stub for CPU-only builds (no CUDA): bilinear backward does
// nothing here. See hl_cnn.h for the real contract of this function.
inline void hl_bilinear_backward(real* inGrad,
                                 const size_t inImgH,
                                 const size_t inImgW,
                                 const size_t inputH,
                                 const size_t inputW,
                                 const real* outGrad,
                                 const size_t outImgH,
                                 const size_t outImgW,
                                 const size_t outputH,
                                 const size_t outputW,
                                 const size_t numChannels,
                                 const real ratioH,
                                 const real ratioW) {}
92120
inline void hl_maxout_forward(
93121
const real* inData, real* outData, int* idData,
94122
size_t batchSize, size_t size, size_t featLen, size_t group) {}

paddle/cuda/src/hl_cuda_cnn.cu

Lines changed: 133 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,7 @@ void hl_CMRNorm_backward(size_t frameCnt, const real* inV,
522522
size_t height, size_t width, size_t sizeX,
523523
real alpha, real beta) {
524524
size_t threadsNum = frameCnt * height * width;
525-
size_t blocksX = (threadsNum + 1024 -1) / 1024;
525+
size_t blocksX = (threadsNum + 1024 - 1) / 1024;
526526
size_t blocksY = 1;
527527
dim3 threads(1024, 1);
528528
dim3 grid(blocksX, blocksY);
@@ -532,6 +532,138 @@ void hl_CMRNorm_backward(size_t frameCnt, const real* inV,
532532
CHECK_SYNC("hl_CMRNorm_backward");
533533
}
534534

535+
// Bilinear-interpolation forward kernel: one thread per output element.
// Data layout: row `b` of in/out holds numChannels planes of H*W values,
// so inputW == numChannels * inImgH * inImgW (and similarly for outputW).
__global__ void KeBilinearInterpFw(const real* in,
                                   const size_t inImgH,
                                   const size_t inImgW,
                                   const size_t inputH,
                                   const size_t inputW,
                                   real* out,
                                   const size_t outImgH,
                                   const size_t outImgW,
                                   const size_t outputH,
                                   const size_t outputW,
                                   const size_t numChannels,
                                   const real ratioH,
                                   const real ratioW) {
  int nthreads = outputH * outputW;  // total output elements
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < nthreads) {
    int outIdH = tid / outputW;  // batch (row) index
    int outIdW = tid % outputW;  // offset within that sample's row
    int inImgSize = inputW / numChannels;    // elements per input plane (inImgH * inImgW)
    int outImgSize = outputW / numChannels;  // elements per output plane (outImgH * outImgW)
    int channelId = outIdW / outImgSize;     // channel of this output element

    int outImgIdy = (outIdW % outImgSize) / outImgW;  // output row in the image
    int inImgIdy = ratioH * outImgIdy;  // top source row (float->int truncation == floor here)
    int hId = (inImgIdy < inImgH - 1) ? 1 : 0;  // row step to bottom neighbor; 0 clamps at last row
    real h1lambda = ratioH * outImgIdy - inImgIdy;  // weight of bottom neighbor
    real h2lambda = 1.f - h1lambda;                 // weight of top neighbor

    // tid % outImgW equals (outIdW % outImgSize) % outImgW because
    // outputW is a multiple of outImgW.
    int outImgIdx = tid % outImgW;
    int inImgIdx = ratioW * outImgIdx;  // left source column
    int wId = (inImgIdx < inImgW - 1) ? 1 : 0;  // column step to right neighbor; 0 clamps at last column
    real w1lambda = ratioW * outImgIdx - inImgIdx;  // weight of right neighbor
    real w2lambda = 1.f - w1lambda;                 // weight of left neighbor

    // Top-left source pixel of this output element.
    const real* inPos =
      &in[outIdH * inputW + channelId * inImgSize + inImgIdy * inImgW + inImgIdx];

    // bilinear interpolation: weighted sum of the 2x2 source neighborhood
    out[outIdH * outputW + outIdW] =
      h2lambda * (w2lambda * inPos[0] + w1lambda * inPos[wId]) +
      h1lambda * (w2lambda * inPos[hId * inImgW] + w1lambda * inPos[hId * inImgW + wId]);
  }
}
579+
/// Host-side launcher for KeBilinearInterpFw: covers every output
/// element with one CUDA thread, then checks for launch errors.
void hl_bilinear_forward(const real* inData,
                         const size_t inImgH,
                         const size_t inImgW,
                         const size_t inputH,
                         const size_t inputW,
                         real* outData,
                         const size_t outImgH,
                         const size_t outImgW,
                         const size_t outputH,
                         const size_t outputW,
                         const size_t numChannels,
                         const real ratioH,
                         const real ratioW) {
  const int numKernels = outputH * outputW;  // one thread per output value
  const int blockSize = 1024;
  const int gridSize = (numKernels + blockSize - 1) / blockSize;

  KeBilinearInterpFw<<<gridSize, 1024, 0, STREAM_DEFAULT>>>(
      inData, inImgH, inImgW, inputH, inputW, outData, outImgH,
      outImgW, outputH, outputW, numChannels, ratioH, ratioW);
  CHECK_SYNC("hl_bilinear_forward failed");
}
601+
// Bilinear-interpolation backward kernel: one thread per OUTPUT gradient
// element, scattering its contribution to the (up to) 4 source pixels.
// Neighboring output pixels share source pixels, so the writes must be
// atomic; gradients are accumulated into `in`, not assigned.
__global__ void KeBilinearInterpBw(real* in,
                                   const size_t inImgH,
                                   const size_t inImgW,
                                   const size_t inputH,
                                   const size_t inputW,
                                   const real* out,
                                   const size_t outImgH,
                                   const size_t outImgW,
                                   const size_t outputH,
                                   const size_t outputW,
                                   const size_t numChannels,
                                   const real ratioH,
                                   const real ratioW) {
  int nthreads = outputH * outputW;  // total output-gradient elements
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < nthreads) {
    int outIdH = tid / outputW;  // batch (row) index
    int outIdW = tid % outputW;  // offset within that sample's row
    int inImgSize = inputW / numChannels;    // elements per input plane (inImgH * inImgW)
    int outImgSize = outputW / numChannels;  // elements per output plane (outImgH * outImgW)
    int channelId = outIdW / outImgSize;     // channel of this gradient element

    int outImgIdy = (outIdW % outImgSize) / outImgW;  // output row in the image
    int inImgIdy = ratioH * outImgIdy;  // top source row (truncation == floor)
    int hId = (inImgIdy < inImgH - 1) ? 1 : 0;  // row step to bottom neighbor; 0 clamps at last row
    real h1lambda = ratioH * outImgIdy - inImgIdy;  // weight of bottom neighbor
    real h2lambda = 1.f - h1lambda;                 // weight of top neighbor

    // Same column derivation as the forward kernel.
    int outImgIdx = tid % outImgW;
    int inImgIdx = ratioW * outImgIdx;  // left source column
    int wId = (inImgIdx < inImgW - 1) ? 1 : 0;  // column step to right neighbor; 0 clamps at last column
    real w1lambda = ratioW * outImgIdx - inImgIdx;  // weight of right neighbor
    real w2lambda = 1.f - w1lambda;                 // weight of left neighbor

    // Top-left source pixel receiving part of this gradient.
    real* inPos =
      &in[outIdH * inputW + channelId * inImgSize + inImgIdy * inImgW + inImgIdx];
    const real* outPos = &out[outIdH * outputW + outIdW];
    // Distribute the output gradient with the same 4 bilinear weights
    // used in the forward pass; atomics because threads collide.
    atomicAdd(&inPos[0], h2lambda * w2lambda * outPos[0]);
    atomicAdd(&inPos[wId], h2lambda * w1lambda * outPos[0]);
    atomicAdd(&inPos[hId * inImgW], h1lambda * w2lambda * outPos[0]);
    atomicAdd(&inPos[hId * inImgW + wId], h1lambda * w1lambda * outPos[0]);
  }
}
645+
/// Host-side launcher for KeBilinearInterpBw: one CUDA thread per
/// output-gradient element, then checks for launch errors.
void hl_bilinear_backward(real* inGrad,
                          const size_t inImgH,
                          const size_t inImgW,
                          const size_t inputH,
                          const size_t inputW,
                          const real* outGrad,
                          const size_t outImgH,
                          const size_t outImgW,
                          const size_t outputH,
                          const size_t outputW,
                          const size_t numChannels,
                          const real ratioH,
                          const real ratioW) {
  const int numKernels = outputH * outputW;  // one thread per gradient value
  const int blockSize = 1024;
  const int gridSize = (numKernels + blockSize - 1) / blockSize;

  KeBilinearInterpBw<<<gridSize, 1024, 0, STREAM_DEFAULT>>>(
      inGrad, inImgH, inImgW, inputH, inputW, outGrad, outImgH,
      outImgW, outputH, outputW, numChannels, ratioH, ratioW);
  CHECK_SYNC("hl_bilinear_backward failed");
}
535667
__global__ void maxoutFpCompute(size_t nthreads, const real * inData,
536668
real * outData, int* idData,
537669
size_t size, size_t featLen, size_t groups) {
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "BilinearInterpLayer.h"
16+
#include "paddle/utils/Logging.h"
17+
#include "paddle/utils/Stat.h"
18+
19+
namespace paddle {
20+
21+
REGISTER_LAYER(bilinear_interp, BilinearInterpLayer);
22+
23+
// Computes the output size (per sample) and, as a side effect, refreshes
// the cached geometry members (inImgH_/inImgW_/outImgH_/outImgW_/
// numChannels_/ratioH_/ratioW_) and stamps the frame size on the output.
size_t BilinearInterpLayer::getSize() {
  // Prefer the frame size reported by the input layer's output ...
  inImgH_ = inputLayers_[0]->getOutput().getFrameHeight();
  inImgW_ = inputLayers_[0]->getOutput().getFrameWidth();

  const BilinearInterpConfig& conf = config_.inputs(0).bilinear_interp_conf();
  // ... falling back to the configured image size when it is unset (0).
  if (inImgH_ == 0) {
    inImgH_ = conf.img_size_y();
  }
  if (inImgW_ == 0) {
    inImgW_ = conf.img_size_x();
  }

  outImgH_ = conf.out_size_y();
  outImgW_ = conf.out_size_x();
  numChannels_ = conf.num_channels();

  CHECK(outImgH_ > 0 && outImgW_ > 0);
  CHECK(inImgH_ > 0 && inImgW_ > 0);
  CHECK(numChannels_);

  // Scale factors are (in - 1) / (out - 1) so the image corners map onto
  // each other; 0 when the output is a single row/column (all outputs
  // then sample input index 0).
  ratioH_ = (outImgH_ > 1) ?
    static_cast<real>(inImgH_ - 1) / (outImgH_ - 1) : 0.f;
  ratioW_ = (outImgW_ > 1) ?
    static_cast<real>(inImgW_ - 1) / (outImgW_ - 1) : 0.f;

  getOutput().setFrameHeight(outImgH_);
  getOutput().setFrameWidth(outImgW_);
  return outImgH_ * outImgW_ * numChannels_;
}
53+
bool BilinearInterpLayer::init(const LayerMap& layerMap,
54+
const ParameterMap& parameterMap) {
55+
/* Initialize the basic parent class */
56+
Layer::init(layerMap, parameterMap);
57+
58+
CHECK_EQ(1, config_.inputs_size());
59+
60+
return true;
61+
}
62+
63+
void BilinearInterpLayer::forward(PassType passType) {
  Layer::forward(passType);

  const size_t batchSize = getInput(0).getBatchSize();
  // getSize() also refreshes the cached image geometry members.
  const size_t dim = getSize();
  {
    REGISTER_TIMER_INFO("FwResetTimer", getName().c_str());
    resetOutput(batchSize, dim);
  }

  MatrixPtr input = getInputValue(0);
  MatrixPtr output = getOutputValue();
  {
    REGISTER_TIMER_INFO("FwBilinearInterpTimer", getName().c_str());
    // Resize each sample from (inImgH_, inImgW_) to (outImgH_, outImgW_).
    output->bilinearForward(*input, inImgH_, inImgW_, outImgH_, outImgW_,
                            numChannels_, ratioH_, ratioW_);
  }
}
82+
void BilinearInterpLayer::backward(const UpdateCallback& callback) {
83+
(void) callback;
84+
85+
MatrixPtr inputG = getInputGrad(0);
86+
MatrixPtr outG = getOutputGrad();
87+
{
88+
REGISTER_TIMER_INFO("BwBilinearInterpTimer", getName().c_str());
89+
if (inputG) {
90+
inputG->bilinearBackward(*outG, outImgH_, outImgW_, inImgH_, inImgW_,
91+
numChannels_, ratioH_, ratioW_);
92+
}
93+
}
94+
}
95+
} // namespace paddle

0 commit comments

Comments
 (0)