Skip to content

Commit d34780e

Browse files
committed
fix issue for resnet
1 parent 2f3665e commit d34780e

File tree

2 files changed

+7
-13
lines changed

2 files changed

+7
-13
lines changed

paddle/gserver/layers/MKLDNNFcLayer.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,18 +60,16 @@ void MKLDNNFcLayer::convertWeightsFromPaddle() {
6060
}
6161

6262
CHECK(wgtVal_) << "should have been initialized";
63-
bool hasNoSpatial_ = ih_ == 1 && iw_ == 1;
6463
auto targetDim = wgtVal_->getDims();
65-
auto srcFmt = hasNoSpatial_ ? format::io : format::ihwo;
64+
auto srcFmt = targetDim.size() == 2 ? format::io : format::ihwo;
6665
wgtVal_->reorderDataFrom(wgtVal_, srcFmt, targetDim);
6766
hasInitedWgt_ = true;
6867
}
6968

7069
void MKLDNNFcLayer::convertWeightsToPaddle() {
7170
CHECK(wgtVal_) << "should have been initialized";
72-
bool hasNoSpatial_ = ih_ == 1 && iw_ == 1;
7371
auto targetDim = wgtVal_->getDims();
74-
auto dstFmt = hasNoSpatial_ ? format::io : format::ihwo;
72+
auto dstFmt = targetDim.size() == 2 ? format::io : format::ihwo;
7573
wgtVal_->reorderDataTo(wgtVal_, dstFmt, targetDim);
7674
}
7775

paddle/gserver/layers/MKLDNNLayer.cpp

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -181,21 +181,17 @@ void MKLDNNLayer::resetInValue(
181181
auto extPD = MKLDNNMatrix::createPrimitiveDesc(
182182
{bs_, ic_, ih_, iw_}, format::nchw, engine_);
183183
const MatrixPtr& inMat = inputLayers_[inputIdx]->getOutputValue();
184-
in = std::dynamic_pointer_cast<MKLDNNMatrix>(inMat);
185-
CHECK_EQ(inputIsOnlyMKLDNN(), in != nullptr);
186-
if (in == nullptr || in->getFormat() == format::nc) {
187-
in = MKLDNNMatrix::create(extPD, inMat);
188-
}
189-
extInVal_ = isPaddleFormat(in->getFormat()) ? in : nullptr;
190-
if (in->getFormat() == format::nc) {
191-
CHECK(ih_ == 1 && iw_ == 1);
184+
extInVal_ = std::dynamic_pointer_cast<MKLDNNMatrix>(inMat);
185+
CHECK_EQ(inputIsOnlyMKLDNN(), extInVal_ != nullptr);
186+
if (extInVal_ == nullptr || extInVal_->getFormat() == format::nc) {
187+
extInVal_ = MKLDNNMatrix::create(extPD, inMat);
192188
}
189+
in = extInVal_;
193190
if (nullptr == intPD || in->getPrimitiveDesc() == *intPD) {
194191
return;
195192
}
196193
// need create reorder
197194
in = MKLDNNMatrix::create(*intPD);
198-
extInVal_ = extInVal_ ? extInVal_ : MKLDNNMatrix::create(extPD, inMat);
199195
cvtInVal_ = MKLDNNMatrix::createReorder(extInVal_, in);
200196
CHECK(cvtInVal_) << "should not be emptry";
201197
}

0 commit comments

Comments
 (0)