Skip to content

Commit 64eaeba

Browse files
committed
enable mkldnn_batch_norm layer
1 parent 02fdf24 commit 64eaeba

File tree

2 files changed

+462
-0
lines changed

2 files changed

+462
-0
lines changed
Lines changed: 326 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,326 @@
1+
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "MKLDNNBatchNormLayer.h"
16+
17+
using namespace mkldnn; // NOLINT
18+
typedef memory::format format;
19+
20+
namespace paddle {
21+
22+
// Register this layer so configs can refer to it by type "mkldnn_batch_norm".
REGISTER_LAYER(mkldnn_batch_norm, MKLDNNBatchNormLayer);

// Epsilon added to the variance before normalization for numerical
// stability; passed to the MKLDNN batch-norm descriptors in resetFwdPD
// and resetBwdPD.
const real MKLDNNBatchNormLayer::EPS = 1E-5;
26+
bool MKLDNNBatchNormLayer::init(const LayerMap& layerMap,
27+
const ParameterMap& parameterMap) {
28+
if (!MKLDNNLayer::init(layerMap, parameterMap)) {
29+
return false;
30+
}
31+
32+
// first one is input layer
33+
// the other two are created in config_parser.py saving moving mean and var
34+
CHECK_EQ(inputLayers_.size(), 3U);
35+
CHECK_EQ(inputLayers_.size(), parameters_.size());
36+
CHECK_EQ(inputLayers_.size(), size_t(config_.inputs_size()));
37+
38+
const ImageConfig& conf = config_.inputs(0).image_conf();
39+
ic_ = conf.channels();
40+
ih_ = inputLayers_[0]->getOutput().getFrameHeight();
41+
iw_ = inputLayers_[0]->getOutput().getFrameWidth();
42+
if (iw_ == 0 && ih_ == 0) {
43+
iw_ = conf.img_size();
44+
ih_ = conf.has_img_size_y() ? conf.img_size_y() : conf.img_size();
45+
}
46+
oc_ = ic_;
47+
oh_ = ih_;
48+
ow_ = iw_;
49+
if (config_.has_use_global_stats()) {
50+
useGlobalStats_ = config_.use_global_stats();
51+
}
52+
movingAvgFraction_ = config_.moving_average_fraction();
53+
VLOG(MKLDNN_BASE) << "--- " << (useGlobalStats_ ? "use" : "do not use")
54+
<< " --- global stats";
55+
VLOG(MKLDNN_BASE) << "Moving average fraction: " << movingAvgFraction_;
56+
57+
initWeight();
58+
movingMean_.reset(new Weight(oc_, 1, parameters_[1], 0));
59+
movingVar_.reset(new Weight(oc_, 1, parameters_[2], 0));
60+
return true;
61+
}
62+
63+
// Sets up the scale (weight) and shift (bias) parameters.
// MKLDNN consumes scale and shift packed into one 2 x oc buffer
// (first oc_ entries = scale, next oc_ entries = shift), so this builds
// packed matrices and then re-points the original parameter buffers at the
// packed storage, keeping the parameter updaters and MKLDNN aliased to the
// same memory.
void MKLDNNBatchNormLayer::initWeight() {
  weight_.reset(new Weight(1, oc_, parameters_[0]));
  if (biasParameter_.get() != NULL) {
    biases_ = std::unique_ptr<Weight>(new Weight(1, oc_, biasParameter_));
  }
  // Scale and shift must be used together, or not at all.
  CHECK_EQ(weight_ != nullptr, biases_ != nullptr)
      << "only support have both weight and bias, or neither";
  if (weight_ && weight_->getW()) {
    CHECK(biases_ && biases_->getW());
    // Packed value buffer: row 0 holds scale, row 1 holds shift.
    valueScaleShift_ = Matrix::create(2, oc_, false, false);
    valueScaleShift_->zeroMem();
    // Views into the packed buffer at offsets 0 (scale) and oc_ (shift).
    VectorPtr scale(new CpuVector(oc_, valueScaleShift_->getMemoryHandle(), 0));
    VectorPtr shift(
        new CpuVector(oc_, valueScaleShift_->getMemoryHandle(), oc_));
    const VectorPtr& wgt = parameters_[0]->getBuf(PARAMETER_VALUE);
    const VectorPtr& bias = biasParameter_->getBuf(PARAMETER_VALUE);
    // Copy the current values into the packed buffer first, THEN alias the
    // parameter buffers to it — this order is essential, otherwise the
    // initial values would be lost.
    scale->copyFrom(*wgt);
    shift->copyFrom(*bias);
    wgt->setData(valueScaleShift_->getData());
    bias->setData(valueScaleShift_->getData() + oc_);
  }
  if (weight_ && weight_->getWGrad()) {
    CHECK(biases_ && biases_->getWGrad());
    // Same packing for the gradients; only aliasing is needed here since
    // gradient contents are produced later by the backward primitive.
    gradScaleShift_ = Matrix::create(2, oc_, false, false);
    gradScaleShift_->zeroMem();
    const VectorPtr& wgt = parameters_[0]->getBuf(PARAMETER_GRADIENT);
    const VectorPtr& bias = biasParameter_->getBuf(PARAMETER_GRADIENT);
    wgt->setData(gradScaleShift_->getData());
    bias->setData(gradScaleShift_->getData() + oc_);
  }
}
94+
95+
void MKLDNNBatchNormLayer::convertWeightsFromPaddle() {
96+
if (hasInitedWgt_) {
97+
return;
98+
}
99+
// prepare mean and var if necessary
100+
if (useGlobalStats_) {
101+
CHECK(mean_);
102+
CHECK(var_);
103+
mean_->copyFrom(*(movingMean_->getW()));
104+
var_->copyFrom(*(movingVar_->getW()));
105+
}
106+
hasInitedWgt_ = true;
107+
}
108+
109+
void MKLDNNBatchNormLayer::calMovingMeanAndVar() {
110+
// calculating and saving moving mean and variance
111+
CHECK_EQ(useGlobalStats_, false);
112+
MatrixPtr movingMean = movingMean_->getW();
113+
MatrixPtr movingVar = movingVar_->getW();
114+
if (FLAGS_trainer_count > 1) {
115+
auto mvMean = std::dynamic_pointer_cast<SharedCpuMatrix>(movingMean);
116+
auto mvVar = std::dynamic_pointer_cast<SharedCpuMatrix>(movingVar);
117+
CHECK(mvMean && mvVar);
118+
mvMean->add(*mean_, movingAvgFraction_, 1.0 - movingAvgFraction_);
119+
mvVar->add(*var_, movingAvgFraction_, 1.0 - movingAvgFraction_);
120+
} else {
121+
movingMean->add(*mean_, movingAvgFraction_, 1.0 - movingAvgFraction_);
122+
// here var is v^2
123+
movingVar->add(*var_, movingAvgFraction_, 1.0 - movingAvgFraction_);
124+
}
125+
}
126+
127+
void MKLDNNBatchNormLayer::reshape(
128+
int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) {
129+
reshapeInput(bs, ih, iw);
130+
oh = ih;
131+
ow = ow;
132+
// ic_ and oc can not be changed
133+
CHECK_EQ(inputElemenCnt_ / bs / ih / iw, (size_t)ic)
134+
<< "Input channel can not be changed";
135+
reshapeOutput(oh, ow);
136+
resizeOutput(bs, oc * oh * ow);
137+
printSizeInfo();
138+
}
139+
140+
void MKLDNNBatchNormLayer::resetFwd(std::vector<primitive>& pipeline,
141+
MKLDNNMatrixPtr& in,
142+
MKLDNNMatrixPtr& wgt,
143+
MKLDNNMatrixPtr& bias,
144+
MKLDNNMatrixPtr& out) {
145+
// in training always calculate mean and var, so useGlobalStats must be false
146+
// in test depends on useGlobalStats
147+
if (passType_ != PASS_TEST && useGlobalStats_ == true) {
148+
LOG(WARNING) << "use_global_stats is invalid setting in training phase";
149+
useGlobalStats_ = false;
150+
}
151+
152+
resetFwdBuffers(in, wgt, out);
153+
154+
resetFwdPD(fwdPD_, in, wgt, out);
155+
156+
resetFwdPipeline(pipeline, fwdPD_, in, wgt, out);
157+
}
158+
159+
void MKLDNNBatchNormLayer::resetBwd(std::vector<primitive>& pipeline,
160+
MKLDNNMatrixPtr& in,
161+
MKLDNNMatrixPtr& wgt,
162+
MKLDNNMatrixPtr& bias,
163+
MKLDNNMatrixPtr& out) {
164+
std::shared_ptr<bn_bwd::primitive_desc> pd;
165+
166+
resetBwdBuffers(in, wgt, out);
167+
168+
resetBwdPD(pd, in, wgt, out);
169+
170+
resetBwdPipeline(pipeline, pd, in, wgt, out);
171+
}
172+
173+
void MKLDNNBatchNormLayer::forward(PassType passType) {
  MKLDNNLayer::forward(passType);

  if (passType_ == PASS_TEST) {
    return;  // inference leaves the moving statistics untouched
  }
  // Training: fold this batch's mean/variance into the moving averages.
  calMovingMeanAndVar();
}
181+
182+
void MKLDNNBatchNormLayer::updateWeights(const UpdateCallback& callback) {
183+
weight_->getParameterPtr()->incUpdate(callback);
184+
if (biases_ && biases_->getWGrad()) {
185+
biases_->getParameterPtr()->incUpdate(callback);
186+
}
187+
}
188+
189+
void MKLDNNBatchNormLayer::resetFwdBuffers(MKLDNNMatrixPtr& in,
                                           MKLDNNMatrixPtr& wgt,
                                           MKLDNNMatrixPtr& out) {
  resetInValue(in);
  CHECK(in);

  // The output keeps the input's memory format with the layer's own dims.
  auto outDims = memory::dims{bs_, oc_, oh_, ow_};
  auto outPD =
      MKLDNNMatrix::createPrimitiveDesc(outDims, in->getFormat(), engine_);
  resetOutValue(out, outPD);

  if (valueScaleShift_) {
    // Packed 2 x oc scale/shift buffer (see initWeight).
    auto wgtPD =
        MKLDNNMatrix::createPrimitiveDesc({2, oc_}, format::nc, engine_);
    resetWithMatrix(wgt, valueScaleShift_, wgtPD);
  }
  if (passType_ != PASS_TEST || useGlobalStats_) {
    // Mean/variance buffers: written during training, read at test time
    // when preloaded global statistics are used.
    auto statPD = MKLDNNMatrix::createPrimitiveDesc({oc_}, format::x, engine_);
    mean_ = MKLDNNMatrix::create(statPD);
    var_ = MKLDNNMatrix::create(statPD);
  }
}
210+
211+
// Builds the forward primitive descriptor and verifies that the buffers
// prepared by resetFwdBuffers() match the layouts the descriptor expects.
void MKLDNNBatchNormLayer::resetFwdPD(
    std::shared_ptr<bn_fwd::primitive_desc>& pd,
    MKLDNNMatrixPtr in,
    MKLDNNMatrixPtr wgt,
    MKLDNNMatrixPtr out) {
  flags_ = 0u;
  // Scoring at test time; training needs the full forward bookkeeping.
  prop_kind pk = passType_ == PASS_TEST ? prop_kind::forward_scoring
                                        : prop_kind::forward_training;
  if (useGlobalStats_) {
    // mean_/var_ become inputs instead of being computed from the batch.
    flags_ = (flags_ | batch_normalization_flag::use_global_stats);
  }
  if (wgt) {
    // Apply the packed scale/shift weights after normalization.
    flags_ = (flags_ | batch_normalization_flag::use_scale_shift);
  }
  auto fwdDesc = bn_fwd::desc(pk, in->getMemoryDesc(), EPS, flags_);
  pd.reset(new bn_fwd::primitive_desc(fwdDesc, engine_));
  // TODO(TJ): use check macro
  CHECK(out);
  CHECK(out->getPrimitiveDesc() == pd->dst_primitive_desc());
  if (wgt) {
    CHECK(wgt->getPrimitiveDesc() == pd->weights_primitive_desc());
  }
  // mean_/var_ exist whenever training (outputs) or when global stats are
  // used at test time (inputs) — mirrors the condition in resetFwdBuffers.
  if (passType_ != PASS_TEST || useGlobalStats_) {
    CHECK(mean_);
    CHECK(mean_->getPrimitiveDesc() == pd->mean_primitive_desc());
    CHECK(var_);
    CHECK(var_->getPrimitiveDesc() == pd->variance_primitive_desc());
  }
}
240+
241+
// Creates the forward primitive and appends it to the pipeline. The bn_fwd
// constructor overload depends on pass type and flags: with global stats at
// test time, mean_/var_ are read-only inputs (hence the primitive::at
// casts); in training they are outputs written by MKLDNN.
void MKLDNNBatchNormLayer::resetFwdPipeline(
    std::vector<primitive>& pipeline,
    std::shared_ptr<bn_fwd::primitive_desc>& pd,
    MKLDNNMatrixPtr& in,
    MKLDNNMatrixPtr& wgt,
    MKLDNNMatrixPtr& out) {
  if (passType_ == PASS_TEST) {
    if (useGlobalStats_) {
      // Inference with preloaded statistics: mean_/var_ go in as inputs.
      fwd_.reset(wgt != nullptr ? new bn_fwd(*pd,
                                             *in,
                                             (const primitive::at)(*mean_),
                                             (const primitive::at)(*var_),
                                             *wgt,
                                             *out)
                                : new bn_fwd(*pd,
                                             *in,
                                             (const primitive::at)(*mean_),
                                             (const primitive::at)(*var_),
                                             *out));
    } else {
      // Inference without global stats: normalize with per-batch statistics.
      fwd_.reset(wgt != nullptr ? new bn_fwd(*pd, *in, *wgt, *out)
                                : new bn_fwd(*pd, *in, *out));
    }
  } else {
    CHECK_EQ(useGlobalStats_, false)
        << "useGlobalStats should be false in training";
    // Training: mean_/var_ are outputs here; forward() later folds them
    // into the moving averages.
    fwd_.reset(wgt != nullptr ? new bn_fwd(*pd, *in, *wgt, *out, *mean_, *var_)
                              : new bn_fwd(*pd, *in, *out, *mean_, *var_));
  }
  pipeline.push_back(*fwd_);
}
272+
273+
// Prepares gradient buffers, reusing the memory layouts already chosen for
// the forward values.
void MKLDNNBatchNormLayer::resetBwdBuffers(MKLDNNMatrixPtr& in,
                                           MKLDNNMatrixPtr& wgt,
                                           MKLDNNMatrixPtr& out) {
  CHECK(inVal_ && outVal_);
  resetOutGrad(out, outVal_->getPrimitiveDesc());
  resetInGrad(in, inVal_->getPrimitiveDesc());
  if (gradScaleShift_) {
    // Scale/shift diffs share the packed forward-weight layout.
    CHECK(wgtVal_);
    resetWithMatrix(wgt, gradScaleShift_, wgtVal_->getPrimitiveDesc());
  }
}
284+
285+
// Builds the backward primitive descriptor (hinted by the forward PD) and
// validates buffer layouts. Leaves pd null when there is no input gradient,
// which resetBwdPipeline treats as "nothing to schedule".
void MKLDNNBatchNormLayer::resetBwdPD(
    std::shared_ptr<bn_bwd::primitive_desc>& pd,
    MKLDNNMatrixPtr& in,
    MKLDNNMatrixPtr& wgt,
    MKLDNNMatrixPtr& out) {
  pd = nullptr;
  if (in == nullptr) {
    return;
  }
  CHECK(out);
  CHECK(out->getPrimitiveDesc() == in->getPrimitiveDesc());
  auto md = in->getMemoryDesc();
  // flags_ was set by resetFwdPD; the forward PD serves as the hint.
  auto bwdDesc = bn_bwd::desc(prop_kind::backward, md, md, EPS, flags_);
  pd.reset(new bn_bwd::primitive_desc(bwdDesc, engine_, *fwdPD_));
  // TODO(TJ): use check macro
  CHECK(wgt);
  CHECK(wgt->getPrimitiveDesc() == pd->diff_weights_primitive_desc());
  CHECK(pd->weights_primitive_desc() == fwdPD_->weights_primitive_desc());
  CHECK(mean_);
  CHECK(mean_->getPrimitiveDesc() == pd->mean_primitive_desc());
  CHECK(var_);
  CHECK(var_->getPrimitiveDesc() == pd->variance_primitive_desc());
}
308+
309+
void MKLDNNBatchNormLayer::resetBwdPipeline(
310+
std::vector<primitive>& pipeline,
311+
std::shared_ptr<bn_bwd::primitive_desc>& pd,
312+
MKLDNNMatrixPtr& in,
313+
MKLDNNMatrixPtr& wgt,
314+
MKLDNNMatrixPtr& out) {
315+
if (pd == nullptr) {
316+
return;
317+
}
318+
CHECK(inVal_);
319+
bwdData_.reset(
320+
wgt && wgtVal_
321+
? new bn_bwd(*pd, *inVal_, *mean_, *var_, *out, *wgtVal_, *in, *wgt)
322+
: new bn_bwd(*pd, *inVal_, *mean_, *var_, *out, *in));
323+
pipeline.push_back(*bwdData_);
324+
}
325+
326+
} // namespace paddle

0 commit comments

Comments
 (0)