Skip to content

Commit 341486d

Browse files
authored
Merge pull request #107 from qingqing01/cudnn_conv
fix cudnn conv bug which occurs in image classfication demo in GTX GPU
2 parents 7eb29f2 + c1c07bb commit 341486d

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

paddle/gserver/layers/CudnnConvLayer.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ bool CudnnConvLayer::init(const LayerMap &layerMap,
8585
biasOffset_ = numFilters_ / groups_[0];
8686
}
8787

88+
batchNum_ = 0;
8889
isSelectAlgo_ = false;
8990
return true;
9091
}
@@ -132,6 +133,11 @@ void CudnnConvLayer::reshape(int batchSize) {
132133
getOutput().setFrameHeight(outputH_);
133134
getOutput().setFrameWidth(outputW_);
134135

136+
// if the batchSize remains the same, set isSelectAlgo_ true.
137+
// Otherwise, set isSelectAlgo_ false and select algo again.
138+
isSelectAlgo_ = (batchSize == batchNum_);
139+
batchNum_ = batchSize;
140+
135141
size_t maxWorkSpace = 0;
136142
for (size_t i = 0; i < inputLayers_.size(); i++) {
137143
CHECK_EQ(inputLayers_[i]->getOutput().value->getWidth(),
@@ -160,6 +166,10 @@ void CudnnConvLayer::reshape(int batchSize) {
160166

161167
maxWorkSpace = std::max(fwdLimitBytes_[i], bwdDataLimitBytes_[i]);
162168
maxWorkSpace = std::max(maxWorkSpace, bwdFilterLimitBytes_[i]);
169+
170+
VLOG(3) << getName() << " Fwd / BwdData / BwdFilter algo: " << fwdAlgo_[i]
171+
<< " / " << bwdDataAlgo_[i]
172+
<< " / " << bwdFilterAlgo_[i];
163173
}
164174
}
165175

paddle/gserver/layers/CudnnConvLayer.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,10 @@ class CudnnConvLayer : public ConvBaseLayer {
8787
/// Is or not select conv algorihtm.
8888
bool isSelectAlgo_;
8989

90+
/// batchNum is used to record batch size. If the batch size is changed,
91+
/// the selection algorithm will be called.
92+
int batchNum_;
93+
9094
public:
9195
explicit CudnnConvLayer(const LayerConfig& config) : ConvBaseLayer(config) {}
9296

0 commit comments

Comments
 (0)