Skip to content

Commit 95da095

Browse files
committed
fix cudnn conv bug which occurs in image classfication demo in GTX GPU
1 parent 7eb29f2 commit 95da095

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

paddle/gserver/layers/CudnnConvLayer.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ bool CudnnConvLayer::init(const LayerMap &layerMap,
8585
biasOffset_ = numFilters_ / groups_[0];
8686
}
8787

88+
batchNum_ = 0;
8889
isSelectAlgo_ = false;
8990
return true;
9091
}
@@ -132,6 +133,9 @@ void CudnnConvLayer::reshape(int batchSize) {
132133
getOutput().setFrameHeight(outputH_);
133134
getOutput().setFrameWidth(outputW_);
134135

136+
isSelectAlgo_ = (batchSize == batchNum_);
137+
batchNum_ = batchSize;
138+
135139
size_t maxWorkSpace = 0;
136140
for (size_t i = 0; i < inputLayers_.size(); i++) {
137141
CHECK_EQ(inputLayers_[i]->getOutput().value->getWidth(),
@@ -160,6 +164,10 @@ void CudnnConvLayer::reshape(int batchSize) {
160164

161165
maxWorkSpace = std::max(fwdLimitBytes_[i], bwdDataLimitBytes_[i]);
162166
maxWorkSpace = std::max(maxWorkSpace, bwdFilterLimitBytes_[i]);
167+
168+
VLOG(3) << getName() << " Fwd / BwdData / BwdFilter algo: " << fwdAlgo_[i]
169+
<< " / " << bwdDataAlgo_[i]
170+
<< " / " << bwdFilterAlgo_[i];
163171
}
164172
}
165173

paddle/gserver/layers/CudnnConvLayer.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,10 @@ class CudnnConvLayer : public ConvBaseLayer {
8787
/// Is or not select conv algorihtm.
8888
bool isSelectAlgo_;
8989

90+
/// batchNum is used to record batch size. If the batch size is changed,
91+
/// the selection algorithm will be called.
92+
int batchNum_;
93+
9094
public:
9195
explicit CudnnConvLayer(const LayerConfig& config) : ConvBaseLayer(config) {}
9296

0 commit comments

Comments
 (0)