Skip to content

Commit 48a6168

Browse files
author
wangyang59
committed
following comments from qingqing01
1 parent b16c0a8 commit 48a6168

12 files changed

+135
-126
lines changed

paddle/gserver/layers/ConvBaseOperator.cpp

Lines changed: 16 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,22 @@ ConvBaseOperator::ConvBaseOperator(const OperatorConfig &config, bool useGpu)
4949
isSelectAlgo_ = false;
5050
}
5151

52-
void ConvBaseOperator::allocConvWorkSpace(size_t maxWorkSpace) {
52+
void ConvBaseOperator::allocConvWorkSpace() {
53+
hl_conv_workspace(imageDesc_,
54+
outputDesc_,
55+
filterDesc_,
56+
convDesc_,
57+
&fwdAlgo_,
58+
&fwdLimitBytes_,
59+
&bwdDataAlgo_,
60+
&bwdDataLimitBytes_,
61+
&bwdFilterAlgo_,
62+
&bwdFilterLimitBytes_);
63+
64+
size_t maxWorkSpace = 0;
65+
maxWorkSpace = std::max(fwdLimitBytes_, bwdDataLimitBytes_);
66+
maxWorkSpace = std::max(maxWorkSpace, bwdFilterLimitBytes_);
67+
5368
if (maxWorkSpace > workSpaceInBytes_) {
5469
if (workSpaceInBytes_ != 0) {
5570
hl_free_mem_device(workSpace_);
@@ -60,59 +75,6 @@ void ConvBaseOperator::allocConvWorkSpace(size_t maxWorkSpace) {
6075
}
6176
}
6277

63-
void ConvBaseOperator::reshape(int batchSize) {
64-
if (isDeconv_) {
65-
outputH_ = ins_[0]->getFrameHeight();
66-
outputW_ = ins_[0]->getFrameWidth();
67-
if (outputH_ == 0) outputH_ = outputY_;
68-
if (outputW_ == 0) outputW_ = outputX_;
69-
imageH_ =
70-
imageSize(outputH_, filterSizeY_, paddingY_, strideY_, caffeMode_);
71-
imageW_ = imageSize(outputW_, filterSize_, padding_, stride_, caffeMode_);
72-
/// Check that the imageSizes are consistent with config
73-
CHECK_EQ(imageH_, imgSizeY_);
74-
CHECK_EQ(imageW_, imgSize_);
75-
out_->setFrameHeight(imageH_);
76-
out_->setFrameWidth(imageW_);
77-
} else {
78-
imageH_ = ins_[0]->getFrameHeight();
79-
imageW_ = ins_[0]->getFrameWidth();
80-
if (imageH_ == 0) imageH_ = imgSizeY_;
81-
if (imageW_ == 0) imageW_ = imgSize_;
82-
outputH_ =
83-
outputSize(imageH_, filterSizeY_, paddingY_, strideY_, caffeMode_);
84-
outputW_ = outputSize(imageW_, filterSize_, padding_, stride_, caffeMode_);
85-
/// Check that the outputSizes are consistent with config
86-
CHECK_EQ(outputH_, outputY_);
87-
CHECK_EQ(outputW_, outputX_);
88-
out_->setFrameHeight(outputH_);
89-
out_->setFrameWidth(outputW_);
90-
}
91-
92-
reshapeImageDescriptors();
93-
94-
if (!isSelectAlgo_) {
95-
hl_conv_workspace(imageDesc_,
96-
outputDesc_,
97-
filterDesc_,
98-
convDesc_,
99-
&fwdAlgo_,
100-
&fwdLimitBytes_,
101-
&bwdDataAlgo_,
102-
&bwdDataLimitBytes_,
103-
&bwdFilterAlgo_,
104-
&bwdFilterLimitBytes_);
105-
106-
size_t maxWorkSpace = 0;
107-
maxWorkSpace = std::max(fwdLimitBytes_, bwdDataLimitBytes_);
108-
maxWorkSpace = std::max(maxWorkSpace, bwdFilterLimitBytes_);
109-
110-
allocConvWorkSpace(maxWorkSpace);
111-
}
112-
113-
isSelectAlgo_ = true;
114-
}
115-
11678
void ConvBaseOperator::computeConvSizes() {
11779
hl_create_filter_descriptor(
11880
&filterDesc_, channels_, numFilters_, filterSizeY_, filterSize_);
@@ -153,15 +115,6 @@ void ConvBaseOperator::reshapeImageDescriptors() {
153115
padding_,
154116
strideY_,
155117
stride_);
156-
157-
if (isDeconv_) {
158-
inputOffset_ = numFilters_ * outputH_ * outputW_;
159-
outputOffset_ = channels_ * imageH_ * imageW_;
160-
} else {
161-
inputOffset_ = channels_ * imageH_ * imageW_;
162-
outputOffset_ = numFilters_ * outputH_ * outputW_;
163-
}
164-
weightOffset_ = numFilters_ * channels_ * filterSize_ * filterSizeY_;
165118
}
166119

167120
void ConvBaseOperator::getConvParams() {

paddle/gserver/layers/ConvBaseOperator.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class ConvBaseOperator : public Operator {
5656
/**
5757
* Allocate Gpu Memory for cudnn convolution algorithms.
5858
*/
59-
void allocConvWorkSpace(size_t maxWorkSpace);
59+
void allocConvWorkSpace();
6060

6161
/**
6262
* Create cudnn tensor descriptor for convolution operation.
@@ -71,7 +71,7 @@ class ConvBaseOperator : public Operator {
7171
/**
7272
* Reshape cudnn tensor descriptor.
7373
*/
74-
void reshape(int batchSize);
74+
virtual void reshape(int batchSize) = 0;
7575

7676
/**
7777
* Check filter size is equal to the size calculated by parameters from

paddle/gserver/layers/ConvBaseProjection.cpp

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -140,19 +140,7 @@ void ConvBaseProjection::reshapeTensorDesc(int batchSize) {
140140
void ConvBaseProjection::reshape(int batchSize) {
141141
size_t width = calOutputSize();
142142
CHECK_EQ(width, out_->value->getWidth());
143-
if (isDeconv_) {
144-
CHECK_EQ(static_cast<size_t>(configChannels_ * outputH_ * outputW_),
145-
in_->value->getWidth())
146-
<< "Wrong input size for convolution transpose"
147-
<< " channels=" << configChannels_ << " outputH=" << outputH_
148-
<< " outputW=" << outputW_ << " inputSize=" << in_->value->getWidth();
149-
} else {
150-
CHECK_EQ(static_cast<size_t>(configChannels_ * imageH_ * imageW_),
151-
in_->value->getWidth())
152-
<< "Wrong input size for convolution"
153-
<< " channels=" << configChannels_ << " imageH=" << imageH_
154-
<< " imageW=" << imageW_ << " inputSize=" << in_->value->getWidth();
155-
}
143+
CHECK_EQ(calInputSize(), in_->value->getWidth());
156144

157145
isSelectAlgo_ = (batchSize == batchNum_);
158146
batchNum_ = batchSize;

paddle/gserver/layers/ConvBaseProjection.h

Lines changed: 2 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -40,54 +40,8 @@ class ConvBaseProjection : public Projection {
4040
void reshapeTensorDesc(int batchSize);
4141
void reshape(int batchSize);
4242

43-
size_t calOutputSize() {
44-
if (isDeconv_) {
45-
outputH_ = in_->getFrameHeight();
46-
outputW_ = in_->getFrameWidth();
47-
if (outputH_ == 0) outputH_ = configOutH_;
48-
if (outputW_ == 0) outputW_ = configOutW_;
49-
imageH_ = imageSize(outputH_,
50-
filterH_,
51-
paddingH_,
52-
strideH_,
53-
/* caffeMode */ true);
54-
55-
imageW_ = imageSize(outputW_,
56-
filterW_,
57-
paddingW_,
58-
strideW_,
59-
/* caffeMode */ true);
60-
61-
const_cast<Argument*>(out_)->setFrameHeight(imageH_);
62-
const_cast<Argument*>(out_)->setFrameWidth(imageW_);
63-
64-
inputOffset_ = (configChannels_ / groups_) * outputH_ * outputW_;
65-
outputOffset_ = (configNumFilters_ / groups_) * imageH_ * imageW_;
66-
return imageH_ * imageW_ * configNumFilters_;
67-
} else {
68-
imageH_ = in_->getFrameHeight();
69-
imageW_ = in_->getFrameWidth();
70-
if (imageH_ == 0) imageH_ = configImgH_;
71-
if (imageW_ == 0) imageW_ = configImgW_;
72-
outputH_ = outputSize(imageH_,
73-
filterH_,
74-
paddingH_,
75-
strideH_,
76-
/* caffeMode */ true);
77-
outputW_ = outputSize(imageW_,
78-
filterW_,
79-
paddingW_,
80-
strideW_,
81-
/* caffeMode */ true);
82-
83-
const_cast<Argument*>(out_)->setFrameHeight(outputH_);
84-
const_cast<Argument*>(out_)->setFrameWidth(outputW_);
85-
86-
inputOffset_ = (configChannels_ / groups_) * imageH_ * imageW_;
87-
outputOffset_ = (configNumFilters_ / groups_) * outputH_ * outputW_;
88-
return outputH_ * outputW_ * configNumFilters_;
89-
}
90-
}
43+
virtual size_t calOutputSize() = 0;
44+
virtual size_t calInputSize() = 0;
9145

9246
static void* getSpaceBytes(size_t size);
9347

paddle/gserver/layers/ConvOperator.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,32 @@ namespace paddle {
2929

3030
REGISTER_OPERATOR(conv, ConvOperator);
3131

32+
void ConvOperator::reshape(int batchSize) {
33+
imageH_ = ins_[0]->getFrameHeight();
34+
imageW_ = ins_[0]->getFrameWidth();
35+
if (imageH_ == 0) imageH_ = imgSizeY_;
36+
if (imageW_ == 0) imageW_ = imgSize_;
37+
outputH_ = outputSize(imageH_, filterSizeY_, paddingY_, strideY_, caffeMode_);
38+
outputW_ = outputSize(imageW_, filterSize_, padding_, stride_, caffeMode_);
39+
/// Check that the outputSizes are consistent with config
40+
CHECK_EQ(outputH_, outputY_);
41+
CHECK_EQ(outputW_, outputX_);
42+
out_->setFrameHeight(outputH_);
43+
out_->setFrameWidth(outputW_);
44+
45+
reshapeImageDescriptors();
46+
47+
inputOffset_ = channels_ * imageH_ * imageW_;
48+
outputOffset_ = numFilters_ * outputH_ * outputW_;
49+
weightOffset_ = numFilters_ * channels_ * filterSize_ * filterSizeY_;
50+
51+
if (!isSelectAlgo_) {
52+
allocConvWorkSpace();
53+
}
54+
55+
isSelectAlgo_ = true;
56+
}
57+
3258
void ConvOperator::forward() {
3359
size_t batchSize = ins_[0]->value->getHeight();
3460
reshape(batchSize);

paddle/gserver/layers/ConvOperator.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class ConvOperator : public ConvBaseOperator {
3838
virtual ~ConvOperator() {}
3939
void forward() override;
4040
void backward() override;
41+
void reshape(int batchSize) override;
4142
};
4243

4344
} // namespace paddle

paddle/gserver/layers/ConvProjection.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,34 @@ namespace paddle {
1919

2020
REGISTER_PROJECTION(conv, ConvProjection);
2121

22+
size_t ConvProjection::calOutputSize() {
23+
imageH_ = in_->getFrameHeight();
24+
imageW_ = in_->getFrameWidth();
25+
if (imageH_ == 0) imageH_ = configImgH_;
26+
if (imageW_ == 0) imageW_ = configImgW_;
27+
outputH_ = outputSize(imageH_,
28+
filterH_,
29+
paddingH_,
30+
strideH_,
31+
/* caffeMode */ true);
32+
outputW_ = outputSize(imageW_,
33+
filterW_,
34+
paddingW_,
35+
strideW_,
36+
/* caffeMode */ true);
37+
38+
const_cast<Argument *>(out_)->setFrameHeight(outputH_);
39+
const_cast<Argument *>(out_)->setFrameWidth(outputW_);
40+
41+
inputOffset_ = (configChannels_ / groups_) * imageH_ * imageW_;
42+
outputOffset_ = (configNumFilters_ / groups_) * outputH_ * outputW_;
43+
return outputH_ * outputW_ * configNumFilters_;
44+
}
45+
46+
size_t ConvProjection::calInputSize() {
47+
return static_cast<size_t>(configChannels_ * imageH_ * imageW_);
48+
}
49+
2250
void ConvProjection::forward() {
2351
int batchSize = in_->value->getHeight();
2452
reshape(batchSize);

paddle/gserver/layers/ConvProjection.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ class ConvProjection : public ConvBaseProjection {
3636

3737
virtual void forward();
3838
virtual void backward(const UpdateCallback& callback);
39+
virtual size_t calOutputSize();
40+
virtual size_t calInputSize();
3941
};
4042

4143
} // namespace paddle

paddle/gserver/layers/ConvTransOperator.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,32 @@ namespace paddle {
2929

3030
REGISTER_OPERATOR(convt, ConvTransOperator);
3131

32+
void ConvTransOperator::reshape(int batchSize) {
33+
outputH_ = ins_[0]->getFrameHeight();
34+
outputW_ = ins_[0]->getFrameWidth();
35+
if (outputH_ == 0) outputH_ = outputY_;
36+
if (outputW_ == 0) outputW_ = outputX_;
37+
imageH_ = imageSize(outputH_, filterSizeY_, paddingY_, strideY_, caffeMode_);
38+
imageW_ = imageSize(outputW_, filterSize_, padding_, stride_, caffeMode_);
39+
/// Check that the imageSizes are consistent with config
40+
CHECK_EQ(imageH_, imgSizeY_);
41+
CHECK_EQ(imageW_, imgSize_);
42+
out_->setFrameHeight(imageH_);
43+
out_->setFrameWidth(imageW_);
44+
45+
reshapeImageDescriptors();
46+
47+
inputOffset_ = numFilters_ * outputH_ * outputW_;
48+
outputOffset_ = channels_ * imageH_ * imageW_;
49+
weightOffset_ = numFilters_ * channels_ * filterSize_ * filterSizeY_;
50+
51+
if (!isSelectAlgo_) {
52+
allocConvWorkSpace();
53+
}
54+
55+
isSelectAlgo_ = true;
56+
}
57+
3258
void ConvTransOperator::forward() {
3359
size_t batchSize = ins_[0]->value->getHeight();
3460
reshape(batchSize);

paddle/gserver/layers/ConvTransOperator.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class ConvTransOperator : public ConvBaseOperator {
3838
virtual ~ConvTransOperator() {}
3939
void forward() override;
4040
void backward() override;
41+
void reshape(int batchSize) override;
4142
};
4243

4344
} // namespace paddle

0 commit comments

Comments
 (0)