
Commit ae7452f

Merge branch 'develop' of github.com:baidu/Paddle into feature/fix_pydataprovider_multiple_obj_bugs

2 parents: 33b8164 + ca0bb40

31 files changed: +1231 −351 lines

doc/source/gserver/layers/layer.rst

Lines changed: 5 additions & 0 deletions
@@ -465,6 +465,11 @@ SumOfSquaresCostLayer
 .. doxygenclass:: paddle::SumOfSquaresCostLayer
    :members:
 
+SumCostLayer
+`````````````````````
+.. doxygenclass:: paddle::SumCostLayer
+   :members:
+
 CosSimLayer
 -----------
 .. doxygenclass:: paddle::CosSimLayer

doc/ui/api/trainer_config_helpers/layers.rst

Lines changed: 24 additions & 0 deletions
@@ -46,6 +46,12 @@ conv_operator
     :members: conv_operator
     :noindex:
 
+conv_projection
+-------------
+.. automodule:: paddle.trainer_config_helpers.layers
+    :members: conv_projection
+    :noindex:
+
 conv_shift_layer
 ------------------
 .. automodule:: paddle.trainer_config_helpers.layers
@@ -71,6 +77,12 @@ img_pool_layer
 --------------
 .. automodule:: paddle.trainer_config_helpers.layers
     :members: img_pool_layer
+    :noindex:
+
+spp_layer
+--------------
+.. automodule:: paddle.trainer_config_helpers.layers
+    :members: spp_layer
     :noindex:
 
 maxout_layer
@@ -254,6 +266,12 @@ expand_layer
     :members: expand_layer
     :noindex:
 
+repeat_layer
+------------
+.. automodule:: paddle.trainer_config_helpers.layers
+    :members: repeat_layer
+    :noindex:
+
 Math Layers
 ===========
 
@@ -401,6 +419,12 @@ hsigmoid
     :members: hsigmoid
     :noindex:
 
+sum_cost
+---------
+.. automodule:: paddle.trainer_config_helpers.layers
+    :members: sum_cost
+    :noindex:
+
 Check Layer
 ============
 

paddle/cuda/include/hl_cnn.h

Lines changed: 10 additions & 4 deletions
@@ -91,6 +91,7 @@ extern void hl_expand_feature2col(
  * @param[in] paddingH padding height.
  * @param[in] paddingW padding width.
  * @param[out] tgtData output data.
+ * @param[in] tgtStride stride between output data samples.
  *
  */
 extern void hl_maxpool_forward(
@@ -100,7 +101,8 @@ extern void hl_maxpool_forward(
     const int pooledH, const int pooledW,
     const int sizeX, const int sizeY,
     const int strideH, const int strideW,
-    const int paddingH, const int paddingW, real* tgtData);
+    const int paddingH, const int paddingW,
+    real* tgtData, const int tgtStride);
 
 /**
  * @brief Maximum pool backward.
@@ -123,6 +125,7 @@ extern void hl_maxpool_forward(
  * @param[in] paddingH padding height.
  * @param[in] paddingW padding width.
  * @param[out] targetGrad output grad.
+ * @param[in] outStride stride between output data samples.
  *
  */
 extern void hl_maxpool_backward(
@@ -135,7 +138,7 @@ extern void hl_maxpool_backward(
     const int strideH, const int strideW,
     const int paddingH, const int paddingW,
     real scaleA, real scaleB,
-    real* targetGrad);
+    real* targetGrad, const int outStride);
 
 /**
  * @brief Average pool forward.
@@ -154,6 +157,7 @@ extern void hl_maxpool_backward(
  * @param[in] paddingH padding height.
  * @param[in] paddingW padding width.
  * @param[out] tgtData output data.
+ * @param[in] tgtStride stride between output data samples.
  *
  */
 extern void hl_avgpool_forward(
@@ -163,7 +167,8 @@ extern void hl_avgpool_forward(
     const int pooledH, const int pooledW,
     const int sizeX, const int sizeY,
     const int strideH, const int strideW,
-    const int paddingH, const int paddingW, real* tgtData);
+    const int paddingH, const int paddingW,
+    real* tgtData, const int tgtStride);
 
 /**
  * @brief Average pool backward.
@@ -184,6 +189,7 @@ extern void hl_avgpool_forward(
  * @param[in] scaleA scale.
  * @param[in] scaleB scale.
  * @param[out] backGrad output grad.
+ * @param[in] outStride stride between output data samples.
  *
  */
 extern void hl_avgpool_backward(
@@ -195,7 +201,7 @@ extern void hl_avgpool_backward(
     const int strideH, const int strideW,
     int paddingH, int paddingW,
     real scaleA, real scaleB,
-    real* backGrad);
+    real* backGrad, const int outStride);
 
 /**
  * @brief Cross-map-response normalize forward.
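What the new tgtStride / outStride parameters buy: the pooled result for sample frameNum now starts at row offset frameNum * tgtStride in the output matrix, instead of being assumed densely packed at channels * pooledH * pooledW values per sample, so a pooling projection can write into a row it shares with other projections. A minimal CPU sketch of that indexing, assuming row-major layout; the function and names below are illustrative, not part of the Paddle API:

#include <vector>

// Hypothetical sketch: copy one sample's pooled result into a row-major output
// whose rows may be wider than one sample (tgtStride >= sampleSize). With
// tgtStride == sampleSize this degenerates to the old densely packed layout.
void writePooledSample(const std::vector<float>& pooled,  // sampleSize values
                       int frameNum,    // sample index within the batch
                       int sampleSize,  // channels * pooledH * pooledW
                       int tgtStride,   // row width of the target matrix
                       float* tgtData) {
  for (int i = 0; i < sampleSize; ++i) {
    // Mirrors the kernels' tgtIndex = index % sampleSize + frameNum * tgtStride.
    tgtData[frameNum * tgtStride + i] = pooled[i];
  }
}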

paddle/cuda/include/stub/hl_cnn_stub.h

Lines changed: 6 additions & 4 deletions
@@ -44,7 +44,8 @@ inline void hl_maxpool_forward(
     const int pooledH, const int pooledW,
     const int sizeX, const int sizeY,
     const int strideH, const int strideW,
-    const int paddingH, const int paddingW, real* tgtData) {}
+    const int paddingH, const int paddingW,
+    real* tgtData, const int tgtStride) {}
 
 inline void hl_maxpool_backward(
     const int frameCnt, const real* inputData,
@@ -56,7 +57,7 @@ inline void hl_maxpool_backward(
     const int strideH, const int strideW,
     const int paddingH, const int paddingW,
     real scaleA, real scaleB,
-    real* targetGrad) {}
+    real* targetGrad, const int outStride) {}
 
 inline void hl_avgpool_forward(
     const int frameCnt, const real* inputData,
@@ -65,7 +66,8 @@ inline void hl_avgpool_forward(
     const int pooledH, const int pooledW,
     const int sizeX, const int sizeY,
     const int strideH, const int strideW,
-    const int paddingH, const int paddingW, real* tgtData) {}
+    const int paddingH, const int paddingW,
+    real* tgtData, const int tgtStride) {}
 
 inline void hl_avgpool_backward(
     const int frameCnt, const real* outGrad,
@@ -76,7 +78,7 @@ inline void hl_avgpool_backward(
     const int strideH, const int strideW,
     int paddingH, int paddingW,
     real scaleA, real scaleB,
-    real* backGrad) {}
+    real* backGrad, const int outStride) {}
 
 inline void hl_CMRNorm_forward(
     size_t frameCnt, const real* in, real* scale, real* out,
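The stubs repeat every signature from hl_cnn.h with an empty body so CPU-only builds still compile and link, which is why the parameter additions above must be mirrored here verbatim. A hedged illustration of the pattern with a made-up function name (the exact header-selection mechanism is presumably driven by the PADDLE_ONLY_CPU build flag):

// Hypothetical sketch of the stub pattern: call sites compile against one
// declaration; GPU builds link the real CUDA definition, while CPU-only
// builds see an empty inline stub, so call sites need no #ifdefs.
#ifndef PADDLE_ONLY_CPU
extern void hl_example_op(const float* in, float* out, int n);  // defined in a .cu file
#else
inline void hl_example_op(const float* in, float* out, int n) {}  // no-op stub
#endif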

paddle/cuda/src/hl_cuda_cnn.cu

Lines changed: 23 additions & 17 deletions
@@ -152,7 +152,7 @@ __global__ void KeMaxPoolForward(const int nthreads, const real* inputData,
                                  const int ksizeW, const int ksizeH,
                                  const int strideH, const int strideW,
                                  const int offsetH, const int offsetW,
-                                 real* tgtData) {
+                                 real* tgtData, const int tgtStride) {
   int index = blockIdx.x * blockDim.x + threadIdx.x;
   if (index < nthreads) {
     int pw = index % pooledW;
@@ -173,7 +173,9 @@ __global__ void KeMaxPoolForward(const int nthreads, const real* inputData,
         maxval = inputData[h * width + w];
       }
     }
-    tgtData[index] = maxval;
+    int tgtIndex = index % (pooledW * pooledH * channels) +
+                   frameNum * tgtStride;
+    tgtData[tgtIndex] = maxval;
   }
 }
 
@@ -184,7 +186,7 @@ void hl_maxpool_forward(const int frameCnt, const real* inputData,
                         const int sizeX, const int sizeY,
                         const int strideH, const int strideW,
                         const int paddingH, const int paddingW,
-                        real* tgtData) {
+                        real* tgtData, const int tgtStride) {
 
   int num_kernels = pooledH * pooledW * channels * frameCnt;
   int blocks = (num_kernels + 1024 - 1) / 1024;
@@ -194,7 +196,7 @@ void hl_maxpool_forward(const int frameCnt, const real* inputData,
   KeMaxPoolForward<<< grid, threads, 0, STREAM_DEFAULT >>>
       (num_kernels, inputData, channels, height, width,
        pooledH, pooledW, sizeX, sizeY, strideH, strideW,
-       paddingH, paddingW, tgtData);
+       paddingH, paddingW, tgtData, tgtStride);
   CHECK_SYNC("hl_maxpool_forward failed");
 }
 
@@ -207,7 +209,7 @@ __global__ void KeMaxPoolBackward(const int nthreads, const real* inputData,
                                   const int strideH, const int strideW,
                                   const int padH, const int padW,
                                   real scaleA, real scaleB,
-                                  real* targetGrad) {
+                                  real* targetGrad, const int outStride) {
   int index = blockIdx.x * blockDim.x + threadIdx.x;
   if (index < nthreads) {
     // find out the local index
@@ -223,8 +225,8 @@ __global__ void KeMaxPoolBackward(const int nthreads, const real* inputData,
     int pwend = offsetW >= 0 ? min(offsetW / strideW + 1, pooledW) : 0;
     real gradient = 0;
     real input = inputData[index];
-    outData += (frameNum * channels + offsetC) * pooledH * pooledW;
-    outGrad += (frameNum * channels + offsetC) * pooledH * pooledW;
+    outData += (frameNum * outStride + offsetC * pooledH * pooledW);
+    outGrad += (frameNum * outStride + offsetC * pooledH * pooledW);
     for (int ph = phstart; ph < phend; ++ph) {
       for (int pw = pwstart; pw < pwend; ++pw) {
         if (input == outData[ph * pooledW + pw]) {
@@ -246,7 +248,7 @@ void hl_maxpool_backward(const int frameCnt, const real* inputData,
                          const int strideH, const int strideW,
                          const int paddingH, const int paddingW,
                          real scaleA, real scaleB,
-                         real* targetGrad) {
+                         real* targetGrad, const int outStride) {
 
   int num_kernels = height * width * channels * frameCnt;
   int blocks = (num_kernels + 1024 - 1) / 1024;
@@ -257,7 +259,7 @@ void hl_maxpool_backward(const int frameCnt, const real* inputData,
        strideH, strideW,
        paddingH, paddingW,
        scaleA, scaleB,
-       targetGrad);
+       targetGrad, outStride);
   CHECK_SYNC("hl_maxpool_backward");
 }
 
@@ -268,7 +270,7 @@ __global__ void KeAvgPoolForward(const int nthreads, const real* inputData,
                                  const int sizeX, const int sizeY,
                                  const int strideH, const int strideW,
                                  const int padH, const int padW,
-                                 real* tgtData) {
+                                 real* tgtData, const int tgtStride) {
   int index = blockIdx.x * blockDim.x + threadIdx.x;
   if (index < nthreads) {
     int pw = index % pooledW;
@@ -293,7 +295,9 @@ __global__ void KeAvgPoolForward(const int nthreads, const real* inputData,
         aveval += inputData[h * width + w];
       }
     }
-    tgtData[index] = aveval / pool_size;
+    int tgtIndex = index % (pooledW * pooledH * channels) +
+                   frameNum * tgtStride;
+    tgtData[tgtIndex] = aveval / pool_size;
   }
 }
 
@@ -303,14 +307,15 @@ void hl_avgpool_forward(const int frameCnt, const real* inputData,
                         const int pooledH, const int pooledW,
                         const int sizeX, const int sizeY,
                         const int strideH, const int strideW,
-                        const int paddingH, const int paddingW, real* tgtData) {
+                        const int paddingH, const int paddingW,
+                        real* tgtData, const int tgtStride) {
   int num_kernels = pooledH * pooledW * channels * frameCnt;
   int blocks = (num_kernels + 1024 - 1) / 1024;
   KeAvgPoolForward<<< blocks, 1024, 0, STREAM_DEFAULT >>>
       (num_kernels, inputData, channels,
        height, width, pooledH, pooledW,
        sizeX, sizeY, strideH, strideW,
-       paddingH, paddingW, tgtData);
+       paddingH, paddingW, tgtData, tgtStride);
   CHECK_SYNC("hl_avgpool_forward failed");
 }
 
@@ -322,7 +327,7 @@ __global__ void KeAvgPoolBackward(const int nthreads, const real* outGrad,
                                   const int strideH, const int strideW,
                                   const int padH, const int padW,
                                   real scaleA, real scaleB,
-                                  real* tgtGrad) {
+                                  real* tgtGrad, const int outStride) {
   int index = blockIdx.x * blockDim.x + threadIdx.x;
   if (index < nthreads) {
     int offsetW = index % width + padW;
@@ -335,7 +340,8 @@ __global__ void KeAvgPoolBackward(const int nthreads, const real* outGrad,
     int phend = offsetH >= 0 ? min(offsetH / strideH + 1, pooledH) : 0;
     int pwend = offsetW >= 0 ? min(offsetW / strideW + 1, pooledW) : 0;
     real gradient = 0;
-    outGrad += (frameNum * channels + offsetC) * pooledH * pooledW;
+    outGrad += (frameNum * outStride + offsetC * pooledH * pooledW);
+
 
     for (int ph = phstart; ph < phend; ++ph) {
       for (int pw = pwstart; pw < pwend; ++pw) {
@@ -360,7 +366,7 @@ void hl_avgpool_backward(const int frameCnt, const real* outGrad,
                          const int strideH, const int strideW,
                          const int paddingH, const int paddingW,
                          real scaleA, real scaleB,
-                         real* backGrad) {
+                         real* backGrad, const int outStride) {
   int num_kernels = height * width * channels * frameCnt;
   int blocks = (num_kernels + 1024 - 1) / 1024;
 
@@ -370,7 +376,7 @@ void hl_avgpool_backward(const int frameCnt, const real* outGrad,
        strideH, strideW,
        paddingH, paddingW,
        scaleA, scaleB,
-       backGrad);
+       backGrad, outStride);
   CHECK_SYNC("hl_avgpool_backward failed");
 }
 
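A sanity check of the remapping the forward kernels now perform: the dense thread index enumerates (frame, channel, ph, pw), index % (pooledW * pooledH * channels) recovers the within-sample offset, and frameNum * tgtStride jumps to that sample's row in the (possibly wider) output. A standalone check with hypothetical sizes:

#include <cassert>

int main() {
  // Hypothetical sizes: 3 channels, 4x4 pooled map, and an output row wide
  // enough to hold two projections of this size side by side.
  const int channels = 3, pooledH = 4, pooledW = 4;
  const int sampleSize = channels * pooledH * pooledW;  // 48
  const int tgtStride = 2 * sampleSize;                 // 96

  // Dense thread index for frame 1, channel 2, ph = 3, pw = 1.
  const int frameNum = 1, c = 2, ph = 3, pw = 1;
  const int index = ((frameNum * channels + c) * pooledH + ph) * pooledW + pw;  // 93

  // The kernels' remapping: keep the within-sample offset, jump by tgtStride.
  const int tgtIndex = index % sampleSize + frameNum * tgtStride;  // 45 + 96
  assert(tgtIndex == (c * pooledH + ph) * pooledW + pw + frameNum * tgtStride);
  return 0;
}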

paddle/gserver/layers/CostLayer.cpp

Lines changed: 35 additions & 0 deletions
@@ -562,4 +562,39 @@ void HuberTwoClass::backwardImpIn(
   }
 }
 
+/**
+ * This cost layer computes the sum of its input as loss.
+ * \f[
+ * o(i) = \sum_{j=1}^D y_{ij}
+ * \f]
+ */
+class SumCostLayer : public Layer {
+public:
+  explicit SumCostLayer(const LayerConfig& config) : Layer(config) {}
+
+  bool init(const LayerMap& layerMap, const ParameterMap& parameterMap) {
+    bool ret = Layer::init(layerMap, parameterMap);
+    if (!ret) return ret;
+    CHECK_EQ(inputLayers_.size(), 1UL);
+    return true;
+  }
+
+  virtual void forward(PassType passType) {
+    Layer::forward(passType);
+    const MatrixPtr& input = getInputValue(0);
+
+    /* malloc memory for the output_ if necessary */
+    int batchSize = input->getHeight();
+    int size = 1;
+    resizeOutput(batchSize, size);
+    output_.value->sumRows(*input);
+  }
+
+  virtual void backward(const UpdateCallback& callback = nullptr) {
+    getInputGrad(0)->add((real)1);
+  }
+};
+
+REGISTER_LAYER(sum_cost, SumCostLayer);
+
 }  // namespace paddle
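Because the loss is a plain sum over the input elements, the gradient with respect to every element is exactly 1, which is why backward() simply adds 1 to the input gradient. A self-contained sketch of the same forward/backward semantics using plain C++ containers instead of Paddle's Matrix (names here are illustrative):

#include <vector>

// Minimal sketch of sum_cost on a [batchSize x dim] row-major input.
// forward: out[i] = sum of row i; backward: dL/dIn is all ones.
void sumCostForward(const std::vector<float>& in, int batchSize, int dim,
                    std::vector<float>& out) {
  out.assign(batchSize, 0.0f);
  for (int i = 0; i < batchSize; ++i)
    for (int j = 0; j < dim; ++j)
      out[i] += in[i * dim + j];  // counterpart of output_.value->sumRows(*input)
}

void sumCostBackward(std::vector<float>& inGrad) {
  for (float& g : inGrad) g += 1.0f;  // counterpart of getInputGrad(0)->add((real)1)
}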

paddle/gserver/layers/CostLayer.h

Lines changed: 1 addition & 1 deletion
@@ -129,7 +129,7 @@ class SoftBinaryClassCrossEntropy : public CostLayer {
  * This cost layer computes Euclidean (L2) loss for real-valued regression
  * tasks.
  * \f[
- * L = \frac{1}{2N} \sum_{i=1}^N {|| \hat{y}_i - y_i||_2^2}
+ * L = \sum_{i=1}^N {|| \hat{y}_i - y_i||_2^2}
  * \f]
  */
 class SumOfSquaresCostLayer : public CostLayer {
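Dropping the \frac{1}{2N} factor rescales the documented loss and, with it, the implied gradient. As a quick check from the two formulas alone (this says nothing about how the implementation actually scales its gradients):

L_{\text{old}} = \frac{1}{2N} \sum_{i=1}^N \|\hat{y}_i - y_i\|_2^2
\;\Rightarrow\;
\frac{\partial L_{\text{old}}}{\partial \hat{y}_i} = \frac{1}{N}(\hat{y}_i - y_i),
\qquad
L_{\text{new}} = \sum_{i=1}^N \|\hat{y}_i - y_i\|_2^2
\;\Rightarrow\;
\frac{\partial L_{\text{new}}}{\partial \hat{y}_i} = 2(\hat{y}_i - y_i).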

paddle/gserver/layers/PoolLayer.cpp

Lines changed: 2 additions & 4 deletions
@@ -52,10 +52,8 @@ bool PoolLayer::init(const LayerMap& layerMap,
 Layer* PoolLayer::create(const LayerConfig& config) {
   CHECK_EQ(config.inputs_size(), 1);
   const std::string& pool = config.inputs(0).pool_conf().pool_type();
-  if (pool == "max-projection") {
-    return new MaxPoolProjectionLayer(config);
-  } else if (pool == "avg-projection") {
-    return new AvgPoolProjectionLayer(config);
+  if (pool == "max-projection" || pool == "avg-projection") {
+    return new PoolProjectionLayer(config);
 #ifndef PADDLE_ONLY_CPU
   } else if (CudnnPoolLayer::typeCheck(pool)) {
     return new CudnnPoolLayer(config);
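The factory previously gave each projection pool type its own class; both strings now map to a single PoolProjectionLayer, which presumably resolves max vs. avg from pool_type at runtime rather than through the class hierarchy. A reduced sketch of that dispatch with hypothetical names (the cuDNN type strings below are illustrative, not Paddle's actual ones):

#include <memory>
#include <stdexcept>
#include <string>

// Hypothetical reduction of PoolLayer::create: one class serves both
// projection modes and reads the mode from its config; other pool types
// keep dedicated classes.
struct Config { std::string poolType; };
struct Layer { virtual ~Layer() = default; };
struct PoolProjection : Layer { explicit PoolProjection(const Config&) {} };
struct CudnnPool : Layer { explicit CudnnPool(const Config&) {} };

std::unique_ptr<Layer> createPool(const Config& cfg) {
  if (cfg.poolType == "max-projection" || cfg.poolType == "avg-projection") {
    return std::make_unique<PoolProjection>(cfg);  // mode resolved inside
  }
  if (cfg.poolType == "cudnn-max-pool" || cfg.poolType == "cudnn-avg-pool") {
    return std::make_unique<CudnnPool>(cfg);  // illustrative type strings
  }
  throw std::invalid_argument("unknown pool type: " + cfg.poolType);
}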
