
Commit b68f2d2

Merge pull request #5049 from tensor-tang/mkldnn_bn
enable mkldnn_batch_norm
2 parents 97fcaef + 5ba1e1e commit b68f2d2

File tree: 10 files changed, +603 −21 lines changed
Lines changed: 318 additions & 0 deletions
@@ -0,0 +1,318 @@
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "MKLDNNBatchNormLayer.h"

using namespace mkldnn;  // NOLINT
typedef memory::format format;

namespace paddle {

REGISTER_LAYER(mkldnn_batch_norm, MKLDNNBatchNormLayer);

const real MKLDNNBatchNormLayer::EPS = 1E-5;

bool MKLDNNBatchNormLayer::init(const LayerMap& layerMap,
                                const ParameterMap& parameterMap) {
  if (!MKLDNNLayer::init(layerMap, parameterMap)) {
    return false;
  }

  // the first one is the input layer;
  // the other two are created in config_parser.py to save the moving mean and var
  CHECK_EQ(inputLayers_.size(), 3U);
  CHECK_EQ(inputLayers_.size(), parameters_.size());
  CHECK_EQ(inputLayers_.size(), size_t(config_.inputs_size()));

  const ImageConfig& conf = config_.inputs(0).image_conf();
  ic_ = conf.channels();
  ih_ = inputLayers_[0]->getOutput().getFrameHeight();
  iw_ = inputLayers_[0]->getOutput().getFrameWidth();
  if (iw_ == 0 && ih_ == 0) {
    iw_ = conf.img_size();
    ih_ = conf.has_img_size_y() ? conf.img_size_y() : conf.img_size();
  }
  oc_ = ic_;
  oh_ = ih_;
  ow_ = iw_;
  if (config_.has_use_global_stats()) {
    useGlobalStats_ = config_.use_global_stats();
  }
  movingAvgFraction_ = config_.moving_average_fraction();
  VLOG(MKLDNN_BASE) << "--- " << (useGlobalStats_ ? "use" : "do not use")
                    << " --- global stats";
  VLOG(MKLDNN_BASE) << "Moving average fraction: " << movingAvgFraction_;

  initWeight();
  movingMean_.reset(new Weight(oc_, 1, parameters_[1], 0));
  movingVar_.reset(new Weight(oc_, 1, parameters_[2], 0));
  return true;
}

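// Pack the scale (gamma) and shift (beta) parameters into a single 2 x oc_
// matrix, since MKL-DNN consumes them as one combined scale-shift weights
// buffer; the original parameter buffers are re-pointed at this memory so
// parameter updates keep working in place.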
void MKLDNNBatchNormLayer::initWeight() {
  weight_.reset(new Weight(1, oc_, parameters_[0]));
  if (biasParameter_.get() != NULL) {
    biases_ = std::unique_ptr<Weight>(new Weight(1, oc_, biasParameter_));
  }
  CHECK_EQ(weight_ != nullptr, biases_ != nullptr)
      << "only support having both weight and bias, or neither";
  if (weight_ && weight_->getW()) {
    CHECK(biases_ && biases_->getW());
    valueScaleShift_ = Matrix::create(2, oc_, false, false);
    valueScaleShift_->zeroMem();
    VectorPtr scale(new CpuVector(oc_, valueScaleShift_->getMemoryHandle(), 0));
    VectorPtr shift(
        new CpuVector(oc_, valueScaleShift_->getMemoryHandle(), oc_));
    const VectorPtr& wgt = parameters_[0]->getBuf(PARAMETER_VALUE);
    const VectorPtr& bias = biasParameter_->getBuf(PARAMETER_VALUE);
    scale->copyFrom(*wgt);
    shift->copyFrom(*bias);
    wgt->setData(valueScaleShift_->getData());
    bias->setData(valueScaleShift_->getData() + oc_);
  }
  if (weight_ && weight_->getWGrad()) {
    CHECK(biases_ && biases_->getWGrad());
    gradScaleShift_ = Matrix::create(2, oc_, false, false);
    gradScaleShift_->zeroMem();
    const VectorPtr& wgt = parameters_[0]->getBuf(PARAMETER_GRADIENT);
    const VectorPtr& bias = biasParameter_->getBuf(PARAMETER_GRADIENT);
    wgt->setData(gradScaleShift_->getData());
    bias->setData(gradScaleShift_->getData() + oc_);
  }
}

void MKLDNNBatchNormLayer::convertWeightsFromPaddle() {
  if (hasInitedWgt_) {
    return;
  }
  // prepare mean and var if necessary
  if (useGlobalStats_) {
    CHECK(mean_);
    CHECK(var_);
    mean_->copyFrom(*(movingMean_->getW()));
    var_->copyFrom(*(movingVar_->getW()));
  }
  hasInitedWgt_ = true;
}

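// Fold the batch statistics produced by the forward primitive into the
// moving (global) mean/variance used at inference time; movingAvgFraction_
// controls how much weight the running history keeps versus the new batch.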
void MKLDNNBatchNormLayer::calMovingMeanAndVar() {
  // calculating and saving moving mean and variance
  CHECK_EQ(useGlobalStats_, false);
  movingMean_->getW()->add(
      *mean_, movingAvgFraction_, 1.0 - movingAvgFraction_);
  // here var is v^2
  movingVar_->getW()->add(*var_, movingAvgFraction_, 1.0 - movingAvgFraction_);
}

void MKLDNNBatchNormLayer::reshape(
    int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) {
  reshapeInput(bs, ih, iw);
  oh = ih;
  ow = iw;
  // ic_ and oc can not be changed
  CHECK_EQ(inputElemenCnt_ / bs / ih / iw, (size_t)ic)
      << "Input channel can not be changed";
  reshapeOutput(oh, ow);
  resizeOutput(bs, oc * oh * ow);
  printSizeInfo();
}

void MKLDNNBatchNormLayer::resetFwd(std::vector<primitive>& pipeline,
                                    MKLDNNMatrixPtr& in,
                                    MKLDNNMatrixPtr& wgt,
                                    MKLDNNMatrixPtr& bias,
                                    MKLDNNMatrixPtr& out) {
  // In the training phase the primitive always calculates mean and var,
  // so useGlobalStats must be false.
  // In the scoring (inference) phase it depends on the useGlobalStats setting.
  if (passType_ != PASS_TEST && useGlobalStats_ == true) {
    LOG(WARNING) << "use_global_stats is an invalid setting in training phase";
    useGlobalStats_ = false;
  }

  resetFwdBuffers(in, wgt, out);

  resetFwdPD(fwdPD_, in, wgt, out);

  resetFwdPipeline(pipeline, fwdPD_, in, wgt, out);
}

void MKLDNNBatchNormLayer::resetBwd(std::vector<primitive>& pipeline,
                                    MKLDNNMatrixPtr& in,
                                    MKLDNNMatrixPtr& wgt,
                                    MKLDNNMatrixPtr& bias,
                                    MKLDNNMatrixPtr& out) {
  std::shared_ptr<bn_bwd::primitive_desc> pd;

  resetBwdBuffers(in, wgt, out);

  resetBwdPD(pd, in, wgt, out);

  resetBwdPipeline(pipeline, pd, in, wgt, out);
}

void MKLDNNBatchNormLayer::forward(PassType passType) {
  MKLDNNLayer::forward(passType);

  // calculate and save moving mean and variance
  if (passType_ != PASS_TEST) {
    calMovingMeanAndVar();
  }
}

void MKLDNNBatchNormLayer::updateWeights(const UpdateCallback& callback) {
  weight_->getParameterPtr()->incUpdate(callback);
  if (biases_ && biases_->getWGrad()) {
    biases_->getParameterPtr()->incUpdate(callback);
  }
}

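// Prepare the forward memories: the output value keeps the input's memory
// format, the packed scale-shift matrix is wrapped as the weights, and the
// mean_/var_ buffers are only created when they are actually used, i.e. in
// training or when inference relies on the global (moving) statistics.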
void MKLDNNBatchNormLayer::resetFwdBuffers(MKLDNNMatrixPtr& in,
                                           MKLDNNMatrixPtr& wgt,
                                           MKLDNNMatrixPtr& out) {
  resetInValue(in);

  memory::dims outDims = memory::dims{bs_, oc_, oh_, ow_};
  CHECK(in);
  auto outPD =
      MKLDNNMatrix::createPrimitiveDesc(outDims, in->getFormat(), engine_);
  resetOutValue(out, outPD);

  if (valueScaleShift_) {
    auto pd = MKLDNNMatrix::createPrimitiveDesc({2, oc_}, format::nc, engine_);
    resetWithMatrix(wgt, valueScaleShift_, pd);
  }
  if (passType_ != PASS_TEST || useGlobalStats_) {
    auto pd = MKLDNNMatrix::createPrimitiveDesc({oc_}, format::x, engine_);
    mean_ = MKLDNNMatrix::create(pd);
    var_ = MKLDNNMatrix::create(pd);
  }
}

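// Build the forward primitive descriptor: prop_kind distinguishes scoring
// from training, and flags_ records use_global_stats / use_scale_shift so the
// backward descriptor can be created with the same flags later.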
void MKLDNNBatchNormLayer::resetFwdPD(
    std::shared_ptr<bn_fwd::primitive_desc>& pd,
    MKLDNNMatrixPtr in,
    MKLDNNMatrixPtr wgt,
    MKLDNNMatrixPtr out) {
  flags_ = 0u;
  prop_kind pk = passType_ == PASS_TEST ? prop_kind::forward_scoring
                                        : prop_kind::forward_training;
  if (useGlobalStats_) {
    flags_ = (flags_ | batch_normalization_flag::use_global_stats);
  }
  if (wgt) {
    flags_ = (flags_ | batch_normalization_flag::use_scale_shift);
  }
  auto fwdDesc = bn_fwd::desc(pk, in->getMemoryDesc(), EPS, flags_);
  pd.reset(new bn_fwd::primitive_desc(fwdDesc, engine_));
  // TODO(TJ): use check macro
  CHECK(out);
  CHECK(out->getPrimitiveDesc() == pd->dst_primitive_desc());
  if (wgt) {
    CHECK(wgt->getPrimitiveDesc() == pd->weights_primitive_desc());
  }
  if (passType_ != PASS_TEST || useGlobalStats_) {
    CHECK(mean_);
    CHECK(mean_->getPrimitiveDesc() == pd->mean_primitive_desc());
    CHECK(var_);
    CHECK(var_->getPrimitiveDesc() == pd->variance_primitive_desc());
  }
}

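// Select the bn_fwd constructor by phase: during inference with global stats
// the moving mean/var are passed in as inputs (primitive::at), while during
// training the primitive computes the batch statistics and writes them to
// mean_ and var_ as outputs.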
void MKLDNNBatchNormLayer::resetFwdPipeline(
    std::vector<primitive>& pipeline,
    std::shared_ptr<bn_fwd::primitive_desc>& pd,
    MKLDNNMatrixPtr& in,
    MKLDNNMatrixPtr& wgt,
    MKLDNNMatrixPtr& out) {
  if (passType_ == PASS_TEST) {
    if (useGlobalStats_) {
      fwd_.reset(wgt != nullptr ? new bn_fwd(*pd,
                                             *in,
                                             (const primitive::at)(*mean_),
                                             (const primitive::at)(*var_),
                                             *wgt,
                                             *out)
                                : new bn_fwd(*pd,
                                             *in,
                                             (const primitive::at)(*mean_),
                                             (const primitive::at)(*var_),
                                             *out));
    } else {
      fwd_.reset(wgt != nullptr ? new bn_fwd(*pd, *in, *wgt, *out)
                                : new bn_fwd(*pd, *in, *out));
    }
  } else {
    CHECK_EQ(useGlobalStats_, false)
        << "useGlobalStats should be false in training";
    fwd_.reset(wgt != nullptr ? new bn_fwd(*pd, *in, *wgt, *out, *mean_, *var_)
                              : new bn_fwd(*pd, *in, *out, *mean_, *var_));
  }
  pipeline.push_back(*fwd_);
}

void MKLDNNBatchNormLayer::resetBwdBuffers(MKLDNNMatrixPtr& in,
                                           MKLDNNMatrixPtr& wgt,
                                           MKLDNNMatrixPtr& out) {
  CHECK(inVal_ && outVal_);
  resetOutGrad(out, outVal_->getPrimitiveDesc());
  resetInGrad(in, inVal_->getPrimitiveDesc());
  if (gradScaleShift_) {
    CHECK(wgtVal_);
    resetWithMatrix(wgt, gradScaleShift_, wgtVal_->getPrimitiveDesc());
  }
}

void MKLDNNBatchNormLayer::resetBwdPD(
    std::shared_ptr<bn_bwd::primitive_desc>& pd,
    MKLDNNMatrixPtr& in,
    MKLDNNMatrixPtr& wgt,
    MKLDNNMatrixPtr& out) {
  pd = nullptr;
  if (in == nullptr) {
    return;
  }
  CHECK(out);
  CHECK(out->getPrimitiveDesc() == in->getPrimitiveDesc());
  auto md = in->getMemoryDesc();
  auto bwdDesc = bn_bwd::desc(prop_kind::backward, md, md, EPS, flags_);
  pd.reset(new bn_bwd::primitive_desc(bwdDesc, engine_, *fwdPD_));
  // TODO(TJ): use check macro
  CHECK(wgt);
  CHECK(wgt->getPrimitiveDesc() == pd->diff_weights_primitive_desc());
  CHECK(pd->weights_primitive_desc() == fwdPD_->weights_primitive_desc());
  CHECK(mean_);
  CHECK(mean_->getPrimitiveDesc() == pd->mean_primitive_desc());
  CHECK(var_);
  CHECK(var_->getPrimitiveDesc() == pd->variance_primitive_desc());
}

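// In the backward primitive the argument named "out" is the gradient w.r.t.
// the layer output (diff_dst) and "in" receives the gradient w.r.t. the
// input (diff_src); the forward input value and the saved mean_/var_ are
// needed to compute them, and diff_weights is produced only when the
// scale-shift weights are used.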
void MKLDNNBatchNormLayer::resetBwdPipeline(
    std::vector<primitive>& pipeline,
    std::shared_ptr<bn_bwd::primitive_desc>& pd,
    MKLDNNMatrixPtr& in,
    MKLDNNMatrixPtr& wgt,
    MKLDNNMatrixPtr& out) {
  if (pd == nullptr) {
    return;
  }
  CHECK(inVal_);
  bwdData_.reset(
      wgt && wgtVal_
          ? new bn_bwd(*pd, *inVal_, *mean_, *var_, *out, *wgtVal_, *in, *wgt)
          : new bn_bwd(*pd, *inVal_, *mean_, *var_, *out, *in));
  pipeline.push_back(*bwdData_);
}

}  // namespace paddle
