Skip to content

Commit ad6b531

Browse files
committed
add unit test for mkldnn_batch_norm layer
1 parent 64eaeba commit ad6b531

File tree

3 files changed

+84
-9
lines changed

3 files changed

+84
-9
lines changed

paddle/gserver/tests/MKLDNNTester.cpp

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,16 @@ void MKLDNNTester::setInputImgSize() {
9191
// init random parameters of ref, and copy to mkldnn
9292
void MKLDNNTester::randomWgtDatas() {
9393
EXPECT_EQ(parameters_[DNN].size(), parameters_[REF].size());
94+
const bool isBN = refLayer_->getType() == "batch_norm";
9495
for (size_t i = 0; i < parameters_[REF].size(); ++i) {
9596
const VectorPtr& dnnValue = parameters_[DNN][i]->getBuf(PARAMETER_VALUE);
9697
const VectorPtr& refValue = parameters_[REF][i]->getBuf(PARAMETER_VALUE);
9798
parameters_[REF][i]->randomize();
99+
if (isBN && i == 2) {
100+
// this param is the moving average in batch norm, which must be larger than 0
101+
real offset = fabs(refValue->getMin()) + 1.0;
102+
refValue->add(offset);
103+
}
98104
dnnValue->copyFrom(*refValue);
99105

100106
VLOG(MKLDNN_TESTS) << "Random weight " << parameters_[DNN][i]->getName();
@@ -132,8 +138,7 @@ void MKLDNNTester::checkForward() {
132138

133139
void MKLDNNTester::checkBackwardData() {
134140
VLOG(MKLDNN_TESTS) << "Check Backward Data";
135-
// TODO(TJ): uncomment me when batch norm ready
136-
// const bool isBN = dnnLayer_->getType() == "mkldnn_batch_norm";
141+
const bool isBN = refLayer_->getType() == "batch_norm";
137142
for (size_t i = 0; i < dataLayers_[DNN].size(); ++i) {
138143
const MatrixPtr& dnnDiff = dataLayers_[DNN][i]->getOutputGrad();
139144
const MatrixPtr& refDiff = dataLayers_[REF][i]->getOutputGrad();
@@ -144,11 +149,11 @@ void MKLDNNTester::checkBackwardData() {
144149

145150
double delta = compareMatrix(dnnDiff, refDiff);
146151
EXPECT_LE(fabs(delta), eps_);
147-
// TODO(TJ): uncomment me when batch norm ready
148-
// if (isBN) {
149-
// // the other two inputs in batch norm are for moving mean and var
150-
// break;
151-
// }
152+
if (isBN) {
153+
// the other two inputs in batch norm are for moving mean and var
154+
// and have no gradients to compare
155+
break;
156+
}
152157
}
153158
}
154159

@@ -308,10 +313,14 @@ double MKLDNNTester::compareVector(const VectorPtr& v1, const VectorPtr& v2) {
308313
void MKLDNNTester::runOnce() {
309314
// test forward
310315
randomBotDatas();
311-
dnnLayer_->forward(PASS_TRAIN);
312-
refLayer_->forward(PASS_TRAIN);
316+
dnnLayer_->forward(passType_);
317+
refLayer_->forward(passType_);
313318
checkForward();
314319

320+
if (passType_ == PASS_TEST) {
321+
return;
322+
}
323+
315324
// test backward
316325
// simple updater
317326
UpdateCallback updateCallback = [](Parameter* para) {
@@ -343,6 +352,7 @@ void MKLDNNTester::run(const TestConfig& dnn,
343352
size_t batchSize,
344353
size_t inputImgH,
345354
size_t inputImgW,
355+
PassType passType,
346356
bool printDetails,
347357
size_t iter,
348358
float epsilon) {
@@ -361,6 +371,7 @@ void MKLDNNTester::run(const TestConfig& dnn,
361371

362372
ih_ = inputImgH;
363373
iw_ = inputImgW;
374+
passType_ = passType;
364375
log_ = printDetails;
365376
iter_ = iter;
366377
eps_ = epsilon;

paddle/gserver/tests/MKLDNNTester.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,15 @@ class MKLDNNTester {
6262
float eps_;
6363
/// input image size, default 1
6464
size_t ih_, iw_;
65+
/// passType, PASS_TRAIN, PASS_TEST or PASS_GC (Gradient Check pass)
66+
PassType passType_;
6567

6668
public:
6769
explicit MKLDNNTester(size_t iter = 3, float epsilon = 1e-4) {
6870
iter_ = iter;
6971
eps_ = epsilon;
7072
log_ = false;
73+
passType_ = PASS_TRAIN;
7174
}
7275

7376
~MKLDNNTester() {}
@@ -78,6 +81,7 @@ class MKLDNNTester {
7881
size_t batchSize,
7982
size_t inputImgH = 1,
8083
size_t inputImgW = 1,
84+
PassType passType = PASS_TRAIN,
8185
bool printDetails = false,
8286
size_t iter = 3,
8387
float epsilon = 1e-4);

paddle/gserver/tests/test_MKLDNN.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,66 @@ TEST(MKLDNNLayer, PoolLayer) {
212212
testPoolLayer({2, 8, 56, 56, 29, 29, 3, 3, 1, 1, 2, 2});
213213
}
214214

215+
struct testBatchNormDesc {
216+
int bs;
217+
int ic;
218+
int ih, iw;
219+
};
220+
221+
static void getMKLDNNBatchNormConfig(TestConfig& cfg,
222+
const testBatchNormDesc& pm) {
223+
cfg.layerConfig.set_size(pm.ic * pm.ih * pm.iw);
224+
cfg.layerConfig.set_type("mkldnn_batch_norm");
225+
cfg.biasSize = pm.ic;
226+
cfg.inputDefs.push_back(
227+
{INPUT_DATA,
228+
"layer_0",
229+
/* size of input layer= */ size_t(pm.ic * pm.ih * pm.iw),
230+
/* size of weight= */ size_t(pm.ic)});
231+
cfg.inputDefs.push_back(
232+
{INPUT_DATA, "layer_1_moving_mean", 1, size_t(pm.ic)});
233+
cfg.inputDefs.back().isStatic = true;
234+
cfg.inputDefs.push_back({INPUT_DATA, "layer_2_moving_var", 1, size_t(pm.ic)});
235+
cfg.inputDefs.back().isStatic = true;
236+
LayerInputConfig* input = cfg.layerConfig.add_inputs();
237+
// TODO(TJ): uncomment me when refine and support comparing all zeroes vector
238+
// cfg.layerConfig.set_active_type("relu");
239+
cfg.layerConfig.add_inputs();
240+
cfg.layerConfig.add_inputs();
241+
ImageConfig* img_conf = input->mutable_image_conf();
242+
img_conf->set_channels(pm.ic);
243+
img_conf->set_img_size_y(pm.ih);
244+
img_conf->set_img_size(pm.iw);
245+
}
246+
247+
void testBatchNormLayer(const testBatchNormDesc& pm) {
248+
TestConfig dnnConfig;
249+
getMKLDNNBatchNormConfig(dnnConfig, pm);
250+
TestConfig refConfig = dnnConfig;
251+
refConfig.layerConfig.set_type("batch_norm");
252+
// for PASS_TRAIN, use_global_stats should always be false, and batchsize != 1
253+
VLOG(MKLDNN_TESTS) << "check train phase";
254+
dnnConfig.layerConfig.set_use_global_stats(false);
255+
refConfig.layerConfig.set_use_global_stats(false);
256+
MKLDNNTester tester;
257+
tester.run(dnnConfig, refConfig, pm.bs, pm.ih, pm.iw, PASS_TRAIN);
258+
// for PASS_TEST, check use_global_stats true and false, and batchsize 1
259+
VLOG(MKLDNN_TESTS) << "check test phase";
260+
for (auto useGS : {false, true}) {
261+
dnnConfig.layerConfig.set_use_global_stats(useGS);
262+
refConfig.layerConfig.set_use_global_stats(useGS);
263+
MKLDNNTester tester;
264+
for (auto bs : {pm.bs, 1}) {
265+
tester.run(dnnConfig, refConfig, bs, pm.ih, pm.iw, PASS_TEST);
266+
}
267+
}
268+
}
269+
270+
TEST(MKLDNNLayer, BatchNormLayer) {
271+
testBatchNormLayer({4, 10, 6, 6});
272+
testBatchNormLayer({16, 32, 16, 16});
273+
}
274+
215275
struct testActDesc {
216276
int bs, ic, ih, iw;
217277
};

0 commit comments

Comments
 (0)