Skip to content

Commit 7039479

Browse files
committed
refine comment and code
1 parent 8845218 commit 7039479

File tree

2 files changed

+12
-18
lines changed

2 files changed

+12
-18
lines changed

paddle/gserver/layers/MKLDNNBatchNormLayer.cpp

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -109,19 +109,10 @@ void MKLDNNBatchNormLayer::convertWeightsFromPaddle() {
109109
void MKLDNNBatchNormLayer::calMovingMeanAndVar() {
110110
// calculating and saving moving mean and variance
111111
CHECK_EQ(useGlobalStats_, false);
112-
MatrixPtr movingMean = movingMean_->getW();
113-
MatrixPtr movingVar = movingVar_->getW();
114-
if (FLAGS_trainer_count > 1) {
115-
auto mvMean = std::dynamic_pointer_cast<SharedCpuMatrix>(movingMean);
116-
auto mvVar = std::dynamic_pointer_cast<SharedCpuMatrix>(movingVar);
117-
CHECK(mvMean && mvVar);
118-
mvMean->add(*mean_, movingAvgFraction_, 1.0 - movingAvgFraction_);
119-
mvVar->add(*var_, movingAvgFraction_, 1.0 - movingAvgFraction_);
120-
} else {
121-
movingMean->add(*mean_, movingAvgFraction_, 1.0 - movingAvgFraction_);
122-
// here var is v^2
123-
movingVar->add(*var_, movingAvgFraction_, 1.0 - movingAvgFraction_);
124-
}
112+
movingMean_->getW()->add(
113+
*mean_, movingAvgFraction_, 1.0 - movingAvgFraction_);
114+
// here var is v^2
115+
movingVar_->getW()->add(*var_, movingAvgFraction_, 1.0 - movingAvgFraction_);
125116
}
126117

127118
void MKLDNNBatchNormLayer::reshape(
@@ -142,8 +133,9 @@ void MKLDNNBatchNormLayer::resetFwd(std::vector<primitive>& pipeline,
142133
MKLDNNMatrixPtr& wgt,
143134
MKLDNNMatrixPtr& bias,
144135
MKLDNNMatrixPtr& out) {
145-
// in training always calculate mean and var, so useGlobalStats must be false
146-
// in test depends on useGlobalStats
136+
// In training phase, it will always calculate mean and var,
137+
// so useGlobalStats must be false.
138+
// In scoring phase, it depends on useGlobalStats choice.
147139
if (passType_ != PASS_TEST && useGlobalStats_ == true) {
148140
LOG(WARNING) << "use_global_stats is invalid setting in training phase";
149141
useGlobalStats_ = false;
@@ -173,7 +165,7 @@ void MKLDNNBatchNormLayer::resetBwd(std::vector<primitive>& pipeline,
173165
void MKLDNNBatchNormLayer::forward(PassType passType) {
174166
MKLDNNLayer::forward(passType);
175167

176-
// calculating and saving moving mean and variance
168+
// calculate and save moving mean and variance
177169
if (passType_ != PASS_TEST) {
178170
calMovingMeanAndVar();
179171
}

paddle/gserver/layers/MKLDNNBatchNormLayer.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,10 @@ class MKLDNNBatchNormLayer : public MKLDNNLayer {
5656
bool hasInitedWgt_;
5757

5858
// local mean and variance
59-
MKLDNNMatrixPtr mean_; // output of mkldnn: m
60-
MKLDNNMatrixPtr var_; // output of mkldnn: v^2
59+
// when useGlobalStats_ they are loaded from moving mean and variance
60+
// when do not useGlobalStats_ they are calculated from this mini-batch
61+
MKLDNNMatrixPtr mean_;
62+
MKLDNNMatrixPtr var_;
6163

6264
public:
6365
explicit MKLDNNBatchNormLayer(const LayerConfig& config)

0 commit comments

Comments
 (0)