@@ -109,19 +109,10 @@ void MKLDNNBatchNormLayer::convertWeightsFromPaddle() {
109
109
void MKLDNNBatchNormLayer::calMovingMeanAndVar () {
110
110
// calculating and saving moving mean and variance
111
111
CHECK_EQ (useGlobalStats_, false );
112
- MatrixPtr movingMean = movingMean_->getW ();
113
- MatrixPtr movingVar = movingVar_->getW ();
114
- if (FLAGS_trainer_count > 1 ) {
115
- auto mvMean = std::dynamic_pointer_cast<SharedCpuMatrix>(movingMean);
116
- auto mvVar = std::dynamic_pointer_cast<SharedCpuMatrix>(movingVar);
117
- CHECK (mvMean && mvVar);
118
- mvMean->add (*mean_, movingAvgFraction_, 1.0 - movingAvgFraction_);
119
- mvVar->add (*var_, movingAvgFraction_, 1.0 - movingAvgFraction_);
120
- } else {
121
- movingMean->add (*mean_, movingAvgFraction_, 1.0 - movingAvgFraction_);
122
- // here var is v^2
123
- movingVar->add (*var_, movingAvgFraction_, 1.0 - movingAvgFraction_);
124
- }
112
+ movingMean_->getW ()->add (
113
+ *mean_, movingAvgFraction_, 1.0 - movingAvgFraction_);
114
+ // here var is v^2
115
+ movingVar_->getW ()->add (*var_, movingAvgFraction_, 1.0 - movingAvgFraction_);
125
116
}
126
117
127
118
void MKLDNNBatchNormLayer::reshape (
@@ -142,8 +133,9 @@ void MKLDNNBatchNormLayer::resetFwd(std::vector<primitive>& pipeline,
142
133
MKLDNNMatrixPtr& wgt,
143
134
MKLDNNMatrixPtr& bias,
144
135
MKLDNNMatrixPtr& out) {
145
- // in training always calculate mean and var, so useGlobalStats must be false
146
- // in test depends on useGlobalStats
136
+ // In training phase, it will always calculate mean and var,
137
+ // so useGlobalStats must be false.
138
+ // In scoring phase, it depends on useGlobalStats choice.
147
139
if (passType_ != PASS_TEST && useGlobalStats_ == true ) {
148
140
LOG (WARNING) << " use_global_stats is invalid setting in training phase" ;
149
141
useGlobalStats_ = false ;
@@ -173,7 +165,7 @@ void MKLDNNBatchNormLayer::resetBwd(std::vector<primitive>& pipeline,
173
165
void MKLDNNBatchNormLayer::forward (PassType passType) {
174
166
MKLDNNLayer::forward (passType);
175
167
176
- // calculating and saving moving mean and variance
168
+ // calculate and save moving mean and variance
177
169
if (passType_ != PASS_TEST) {
178
170
calMovingMeanAndVar ();
179
171
}
0 commit comments