Skip to content

Commit 343b1a9

Browse files
committed
add mkldnn_lrn unit test
1 parent 54205c9 commit 343b1a9

File tree

1 file changed

+45
-0
lines changed

1 file changed

+45
-0
lines changed

paddle/gserver/tests/test_MKLDNN.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,51 @@ TEST(MKLDNNLayer, BatchNormLayer) {
272272
testBatchNormLayer({4, 16, 8, 10});
273273
}
274274

275+
struct testLRNDesc {
276+
int bs, ic, ih, iw;
277+
float scale, pow;
278+
int localSize;
279+
};
280+
281+
void getMKLDNNLRNConfig(TestConfig& cfg, const testLRNDesc& pm) {
282+
cfg.layerConfig.set_type("mkldnn_lrn");
283+
cfg.layerConfig.set_active_type("relu");
284+
size_t layerSize = pm.ic * pm.ih * pm.iw;
285+
cfg.inputDefs.push_back({INPUT_DATA, "layer_0", layerSize, 0});
286+
LayerInputConfig* input = cfg.layerConfig.add_inputs();
287+
NormConfig* norm = input->mutable_norm_conf();
288+
norm->set_channels(pm.ic);
289+
norm->set_size(pm.localSize);
290+
norm->set_scale(pm.scale);
291+
norm->set_pow(pm.pow);
292+
norm->set_blocked(0);
293+
norm->set_img_size(pm.iw);
294+
norm->set_img_size_y(pm.ih);
295+
norm->set_output_x(norm->img_size());
296+
norm->set_output_y(norm->img_size_y());
297+
cfg.layerConfig.set_size(layerSize);
298+
cfg.biasSize = 0;
299+
}
300+
301+
void testLRNLayer(const testLRNDesc& pm) {
302+
TestConfig dnnConfig;
303+
getMKLDNNLRNConfig(dnnConfig, pm);
304+
// mkldnn_lrn <==> norm with cmrnorm-projection type
305+
TestConfig refConfig = dnnConfig;
306+
refConfig.layerConfig.set_type("norm");
307+
LayerInputConfig* input = refConfig.layerConfig.mutable_inputs(0);
308+
NormConfig* norm = input->mutable_norm_conf();
309+
norm->set_norm_type("cmrnorm-projection");
310+
norm->set_scale(norm->scale() / norm->size());
311+
RUN_MKLDNN_TEST(dnnConfig, refConfig, pm)
312+
}
313+
314+
TEST(MKLDNNLayer, LRNLayer) {
315+
testLRNLayer({4, 10, 12, 12, 0.001f, 0.75f, 5});
316+
testLRNLayer({2, 32, 6, 6, 0.001f, 0.75f, 5});
317+
testLRNLayer({4, 16, 8, 10, 0.01f, 0.5f, 5});
318+
}
319+
275320
struct testImageDesc {
276321
int bs, ic, ih, iw;
277322
};

0 commit comments

Comments
 (0)