Skip to content

Commit 70e0468

Browse files
committed
add getSize method for PoolProjection
1 parent bdc9d10 commit 70e0468

File tree

4 files changed

+24
-6
lines changed

4 files changed

+24
-6
lines changed

paddle/gserver/layers/PoolProjection.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ PoolProjection* PoolProjection::create(const ProjectionConfig& config,
3232
}
3333

3434
void MaxPoolProjection::forward() {
35+
size_t width = getSize();
36+
CHECK_EQ(width, out_->value->getWidth());
3537
MatrixPtr inputV = in_->value;
3638
MatrixPtr outV = out_->value;
3739
outV->maxPoolForward(*inputV, imgSizeY_, imgSize_, channels_, sizeX_, sizeY_,
@@ -55,6 +57,8 @@ void MaxPoolProjection::backward(const UpdateCallback& callback) {
5557
}
5658

5759
void AvgPoolProjection::forward() {
60+
size_t width = getSize();
61+
CHECK_EQ(width, out_->value->getWidth());
5862
MatrixPtr inputV = in_->value;
5963
MatrixPtr outV = out_->value;
6064
outV->avgPoolForward(*inputV, imgSizeY_, imgSize_, channels_, sizeX_, sizeY_,

paddle/gserver/layers/PoolProjection.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,26 @@ class PoolProjection : public Projection {
5151
static PoolProjection* create(const ProjectionConfig& config,
5252
ParameterPtr parameter, bool useGpu);
5353
const std::string& getPoolType() const { return poolType_; }
54+
size_t getSize() {
55+
imgSizeY_ = in_->getFrameHeight();
56+
imgSize_ = in_->getFrameWidth();
57+
const PoolConfig& conf = config_.pool_conf();
58+
if (imgSizeY_ == 0) {
59+
imgSizeY_ = conf.has_img_size_y() ? conf.img_size_y() : conf.img_size();
60+
}
61+
if (imgSize_ == 0) {
62+
imgSize_ = conf.img_size();
63+
}
64+
outputY_ = outputSize(imgSizeY_, sizeY_, confPaddingY_, strideY_,
65+
/* caffeMode */ false);
66+
outputX_ = outputSize(imgSize_, sizeX_, confPadding_, stride_,
67+
/* caffeMode */ false);
68+
69+
const_cast<Argument*>(out_)->setFrameHeight(outputY_);
70+
const_cast<Argument*>(out_)->setFrameWidth(outputX_);
71+
72+
return outputY_ * outputX_ * channels_;
73+
}
5474
};
5575

5676
class MaxPoolProjection : public PoolProjection {

paddle/gserver/layers/PoolProjectionLayer.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,6 @@ size_t PoolProjectionLayer::getSize() {
3838

3939
layerSize = outputH_ * outputW_ * channels_;
4040

41-
getOutput().setFrameHeight(outputH_);
42-
getOutput().setFrameWidth(outputW_);
4341
return layerSize;
4442
}
4543

paddle/gserver/layers/SpatialPyramidPoolLayer.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,6 @@ size_t SpatialPyramidPoolLayer::getSize() {
7070
size_t outputW = (std::pow(4, pyramidHeight_) - 1) / (4 - 1);
7171

7272
layerSize = outputH * outputW * channels_;
73-
74-
getOutput().setFrameHeight(outputH);
75-
getOutput().setFrameWidth(outputW);
76-
7773
return layerSize;
7874
}
7975

0 commit comments

Comments
 (0)