Skip to content

Commit 9dd588b

Browse files
committed
fix merge conflicts
2 parents 61444d9 + 8295eb9 commit 9dd588b

31 files changed

+2129
-344
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,6 @@ build/
55
.vscode
66
.idea
77
.project
8+
.cproject
89
.pydevproject
10+
Makefile

doc/ui/api/trainer_config_helpers/layers.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,12 @@ interpolation_layer
287287
:members: interpolation_layer
288288
:noindex:
289289

290+
bilinear_interp_layer
291+
----------------------
292+
.. automodule:: paddle.trainer_config_helpers.layers
293+
:members: bilinear_interp_layer
294+
:noindex:
295+
290296
power_layer
291297
-----------
292298
.. automodule:: paddle.trainer_config_helpers.layers

paddle/cuda/include/hl_cnn.h

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,70 @@ extern void hl_CMRNorm_backward(
246246
size_t channels, size_t height, size_t width, size_t sizeX,
247247
real alpha, real beta);
248248

249+
/**
 * @brief Bilinear interpolation forward.
 *
 * Input and output are row-major matrices of shape
 * (batchSize x numChannels * imgH * imgW).
 *
 * @param[in]  inData       input matrix data.
 * @param[in]  inImgH       input image height.
 * @param[in]  inImgW       input image width.
 * @param[in]  inputH       input batch size.
 * @param[in]  inputW       input row width (numChannels * inImgH * inImgW).
 * @param[out] outData      output matrix data.
 * @param[in]  outImgH      output image height.
 * @param[in]  outImgW      output image width.
 * @param[in]  outputH      output batch size.
 * @param[in]  outputW      output row width (numChannels * outImgH * outImgW).
 * @param[in]  numChannels  number of channels.
 * @param[in]  ratioH       vertical scaling ratio mapping output rows to
 *                          input rows.
 * @param[in]  ratioW       horizontal scaling ratio mapping output columns
 *                          to input columns.
 */
extern void hl_bilinear_forward(const real* inData,
                                const size_t inImgH,
                                const size_t inImgW,
                                const size_t inputH,
                                const size_t inputW,
                                real* outData,
                                const size_t outImgH,
                                const size_t outImgW,
                                const size_t outputH,
                                const size_t outputW,
                                const size_t numChannels,
                                const real ratioH,
                                const real ratioW);
280+
281+
/**
 * @brief Bilinear interpolation backward.
 *
 * Accumulates the output gradient back into the input gradient; the
 * gradient matrices share the layout documented for the forward pass.
 *
 * @param[out] inGrad       input gradient, accumulated into.
 * @param[in]  inImgH       input image height.
 * @param[in]  inImgW       input image width.
 * @param[in]  inputH       input batch size.
 * @param[in]  inputW       input row width (numChannels * inImgH * inImgW).
 * @param[in]  outGrad      output gradient.
 * @param[in]  outImgH      output image height.
 * @param[in]  outImgW      output image width.
 * @param[in]  outputH      output batch size.
 * @param[in]  outputW      output row width (numChannels * outImgH * outImgW).
 * @param[in]  numChannels  number of channels.
 * @param[in]  ratioH       vertical scaling ratio mapping output rows to
 *                          input rows.
 * @param[in]  ratioW       horizontal scaling ratio mapping output columns
 *                          to input columns.
 */
extern void hl_bilinear_backward(real* inGrad,
                                 const size_t inImgH,
                                 const size_t inImgW,
                                 const size_t inputH,
                                 const size_t inputW,
                                 const real* outGrad,
                                 const size_t outImgH,
                                 const size_t outImgW,
                                 const size_t outputH,
                                 const size_t outputW,
                                 const size_t numChannels,
                                 const real ratioH,
                                 const real ratioW);
312+
249313
/**
250314
* @brief MaxOut forward.
251315
*

paddle/cuda/include/stub/hl_cnn_stub.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,34 @@ inline void hl_CMRNorm_backward(
9191
size_t channels, size_t height, size_t width, size_t sizeX,
9292
real alpha, real beta) {}
9393

94+
// No-op stub so that builds without GPU support still link.
inline void hl_bilinear_forward(const real* inData, const size_t inImgH,
                                const size_t inImgW, const size_t inputH,
                                const size_t inputW, real* outData,
                                const size_t outImgH, const size_t outImgW,
                                const size_t outputH, const size_t outputW,
                                const size_t numChannels, const real ratioH,
                                const real ratioW) {}
107+
108+
// No-op stub so that builds without GPU support still link.
inline void hl_bilinear_backward(real* inGrad, const size_t inImgH,
                                 const size_t inImgW, const size_t inputH,
                                 const size_t inputW, const real* outGrad,
                                 const size_t outImgH, const size_t outImgW,
                                 const size_t outputH, const size_t outputW,
                                 const size_t numChannels, const real ratioH,
                                 const real ratioW) {}
121+
94122
inline void hl_maxout_forward(
95123
const real* inData, real* outData, int* idData,
96124
size_t batchSize, size_t size, size_t featLen, size_t group) {}

paddle/cuda/src/hl_cuda_cnn.cu

Lines changed: 133 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,7 @@ void hl_CMRNorm_backward(size_t frameCnt, const real* inV,
528528
size_t height, size_t width, size_t sizeX,
529529
real alpha, real beta) {
530530
size_t threadsNum = frameCnt * height * width;
531-
size_t blocksX = (threadsNum + 1024 -1) / 1024;
531+
size_t blocksX = (threadsNum + 1024 - 1) / 1024;
532532
size_t blocksY = 1;
533533
dim3 threads(1024, 1);
534534
dim3 grid(blocksX, blocksY);
@@ -538,6 +538,138 @@ void hl_CMRNorm_backward(size_t frameCnt, const real* inV,
538538
CHECK_SYNC("hl_CMRNorm_backward");
539539
}
540540

541+
/*
 * Bilinear-interpolation forward kernel.
 *
 * Launch layout: 1-D grid of 1-D blocks, one thread per output element;
 * threads with tid >= outputH * outputW exit immediately.  `in` and `out`
 * are row-major (batch x channels*imgH*imgW) matrices.
 */
__global__ void KeBilinearInterpFw(const real* in,
                                   const size_t inImgH,
                                   const size_t inImgW,
                                   const size_t inputH,
                                   const size_t inputW,
                                   real* out,
                                   const size_t outImgH,
                                   const size_t outImgW,
                                   const size_t outputH,
                                   const size_t outputW,
                                   const size_t numChannels,
                                   const real ratioH,
                                   const real ratioW) {
  int total = outputH * outputW;
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= total) return;  // guard the grid tail

  // Decompose the flat id into (sample row, element within that row).
  int batchId = tid / outputW;
  int elemId = tid % outputW;
  int inImgSize = inputW / numChannels;    // pixels per input channel image
  int outImgSize = outputW / numChannels;  // pixels per output channel image
  int channelId = elemId / outImgSize;

  // Vertical source coordinate and interpolation weights; the integer
  // truncation selects the upper source row.
  int outImgIdy = (elemId % outImgSize) / outImgW;
  int inImgIdy = ratioH * outImgIdy;
  int hId = (inImgIdy < inImgH - 1) ? 1 : 0;  // 0 collapses the bottom border
  real h1lambda = ratioH * outImgIdy - inImgIdy;
  real h2lambda = 1.f - h1lambda;

  // Horizontal source coordinate and interpolation weights.
  int outImgIdx = tid % outImgW;
  int inImgIdx = ratioW * outImgIdx;
  int wId = (inImgIdx < inImgW - 1) ? 1 : 0;  // 0 collapses the right border
  real w1lambda = ratioW * outImgIdx - inImgIdx;
  real w2lambda = 1.f - w1lambda;

  const real* inPos = &in[batchId * inputW + channelId * inImgSize +
                          inImgIdy * inImgW + inImgIdx];

  // Weighted sum of the 2x2 neighborhood around the source coordinate.
  out[batchId * outputW + elemId] =
      h2lambda * (w2lambda * inPos[0] + w1lambda * inPos[wId]) +
      h1lambda * (w2lambda * inPos[hId * inImgW] +
                  w1lambda * inPos[hId * inImgW + wId]);
}
584+
585+
// Host wrapper: launches KeBilinearInterpFw with one thread per output
// element on the default stream, then checks for launch errors.
void hl_bilinear_forward(const real* inData,
                         const size_t inImgH,
                         const size_t inImgW,
                         const size_t inputH,
                         const size_t inputW,
                         real* outData,
                         const size_t outImgH,
                         const size_t outImgW,
                         const size_t outputH,
                         const size_t outputW,
                         const size_t numChannels,
                         const real ratioH,
                         const real ratioW) {
  const int blockSize = 1024;  // threads per block
  int totalThreads = outputH * outputW;
  int blockNum = (totalThreads + blockSize - 1) / blockSize;  // ceil-div

  KeBilinearInterpFw<<<blockNum, blockSize, 0, STREAM_DEFAULT>>>(
      inData, inImgH, inImgW, inputH, inputW, outData, outImgH,
      outImgW, outputH, outputW, numChannels, ratioH, ratioW);
  CHECK_SYNC("hl_bilinear_forward failed");
}
606+
607+
/*
 * Bilinear-interpolation backward kernel.
 *
 * One thread per OUTPUT gradient element; each thread scatters its value
 * into the four source pixels it was interpolated from.  Several output
 * pixels can map to the same input pixel, hence the atomics.
 * NOTE(review): atomicAdd on a double-precision `real` requires SM60+ —
 * confirm the build's minimum compute capability.
 */
__global__ void KeBilinearInterpBw(real* in,
                                   const size_t inImgH,
                                   const size_t inImgW,
                                   const size_t inputH,
                                   const size_t inputW,
                                   const real* out,
                                   const size_t outImgH,
                                   const size_t outImgW,
                                   const size_t outputH,
                                   const size_t outputW,
                                   const size_t numChannels,
                                   const real ratioH,
                                   const real ratioW) {
  int total = outputH * outputW;
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= total) return;  // guard the grid tail

  // Decompose the flat id into (sample row, element within that row).
  int batchId = tid / outputW;
  int elemId = tid % outputW;
  int inImgSize = inputW / numChannels;    // pixels per input channel image
  int outImgSize = outputW / numChannels;  // pixels per output channel image
  int channelId = elemId / outImgSize;

  // Vertical source coordinate and interpolation weights (same mapping as
  // the forward kernel).
  int outImgIdy = (elemId % outImgSize) / outImgW;
  int inImgIdy = ratioH * outImgIdy;
  int hId = (inImgIdy < inImgH - 1) ? 1 : 0;  // 0 collapses the bottom border
  real h1lambda = ratioH * outImgIdy - inImgIdy;
  real h2lambda = 1.f - h1lambda;

  // Horizontal source coordinate and interpolation weights.
  int outImgIdx = tid % outImgW;
  int inImgIdx = ratioW * outImgIdx;
  int wId = (inImgIdx < inImgW - 1) ? 1 : 0;  // 0 collapses the right border
  real w1lambda = ratioW * outImgIdx - inImgIdx;
  real w2lambda = 1.f - w1lambda;

  real* inPos = &in[batchId * inputW + channelId * inImgSize +
                    inImgIdy * inImgW + inImgIdx];
  const real* outPos = &out[batchId * outputW + elemId];

  // Distribute the gradient over the 2x2 source neighborhood.
  atomicAdd(&inPos[0], h2lambda * w2lambda * outPos[0]);
  atomicAdd(&inPos[wId], h2lambda * w1lambda * outPos[0]);
  atomicAdd(&inPos[hId * inImgW], h1lambda * w2lambda * outPos[0]);
  atomicAdd(&inPos[hId * inImgW + wId], h1lambda * w1lambda * outPos[0]);
}
650+
651+
// Host wrapper: launches KeBilinearInterpBw with one thread per output
// gradient element on the default stream, then checks for launch errors.
void hl_bilinear_backward(real* inGrad,
                          const size_t inImgH,
                          const size_t inImgW,
                          const size_t inputH,
                          const size_t inputW,
                          const real* outGrad,
                          const size_t outImgH,
                          const size_t outImgW,
                          const size_t outputH,
                          const size_t outputW,
                          const size_t numChannels,
                          const real ratioH,
                          const real ratioW) {
  const int blockSize = 1024;  // threads per block
  int totalThreads = outputH * outputW;
  int blockNum = (totalThreads + blockSize - 1) / blockSize;  // ceil-div

  KeBilinearInterpBw<<<blockNum, blockSize, 0, STREAM_DEFAULT>>>(
      inGrad, inImgH, inImgW, inputH, inputW, outGrad, outImgH,
      outImgW, outputH, outputW, numChannels, ratioH, ratioW);
  CHECK_SYNC("hl_bilinear_backward failed");
}
672+
541673
__global__ void maxoutFpCompute(size_t nthreads, const real * inData,
542674
real * outData, int* idData,
543675
size_t size, size_t featLen, size_t groups) {
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "BilinearInterpLayer.h"
16+
#include "paddle/utils/Logging.h"
17+
#include "paddle/utils/Stat.h"
18+
19+
namespace paddle {
20+
21+
REGISTER_LAYER(bilinear_interp, BilinearInterpLayer);
22+
23+
size_t BilinearInterpLayer::getSize() {
  // Prefer the frame size recorded on the input layer's output; fall back
  // to the sizes declared in the layer config when they are unset (0).
  inImgH_ = inputLayers_[0]->getOutput().getFrameHeight();
  inImgW_ = inputLayers_[0]->getOutput().getFrameWidth();

  const BilinearInterpConfig& conf = config_.inputs(0).bilinear_interp_conf();
  if (inImgH_ == 0) {
    inImgH_ = conf.img_size_y();
  }
  if (inImgW_ == 0) {
    inImgW_ = conf.img_size_x();
  }

  outImgH_ = conf.out_size_y();
  outImgW_ = conf.out_size_x();
  numChannels_ = conf.num_channels();

  CHECK(inImgH_ > 0 && inImgW_ > 0);
  CHECK(outImgH_ > 0 && outImgW_ > 0);
  CHECK(numChannels_);

  // Scaling ratios map output pixel indices to input pixel indices so that
  // corner pixels coincide; a 1-pixel output dimension degenerates to 0.
  ratioH_ =
      (outImgH_ > 1) ? static_cast<real>(inImgH_ - 1) / (outImgH_ - 1) : 0.f;
  ratioW_ =
      (outImgW_ > 1) ? static_cast<real>(inImgW_ - 1) / (outImgW_ - 1) : 0.f;

  getOutput().setFrameHeight(outImgH_);
  getOutput().setFrameWidth(outImgW_);
  return outImgH_ * outImgW_ * numChannels_;
}
52+
53+
bool BilinearInterpLayer::init(const LayerMap& layerMap,
54+
const ParameterMap& parameterMap) {
55+
/* Initialize the basic parent class */
56+
Layer::init(layerMap, parameterMap);
57+
58+
CHECK_EQ(1, config_.inputs_size());
59+
60+
return true;
61+
}
62+
63+
void BilinearInterpLayer::forward(PassType passType) {
  Layer::forward(passType);

  size_t batchSize = getInput(0).getBatchSize();
  size_t outputSize = getSize();
  {
    REGISTER_TIMER_INFO("FwResetTimer", getName().c_str());
    resetOutput(batchSize, outputSize);
  }

  MatrixPtr input = getInputValue(0);
  MatrixPtr output = getOutputValue();
  {
    REGISTER_TIMER_INFO("FwBilinearInterpTimer", getName().c_str());
    // Delegate the actual interpolation to the matrix implementation.
    output->bilinearForward(*input, inImgH_, inImgW_, outImgH_, outImgW_,
                            numChannels_, ratioH_, ratioW_);
  }
}
81+
82+
void BilinearInterpLayer::backward(const UpdateCallback& callback) {
83+
(void) callback;
84+
85+
MatrixPtr inputG = getInputGrad(0);
86+
MatrixPtr outG = getOutputGrad();
87+
{
88+
REGISTER_TIMER_INFO("BwBilinearInterpTimer", getName().c_str());
89+
if (inputG) {
90+
inputG->bilinearBackward(*outG, outImgH_, outImgW_, inImgH_, inImgW_,
91+
numChannels_, ratioH_, ratioW_);
92+
}
93+
}
94+
}
95+
} // namespace paddle

0 commit comments

Comments
 (0)