Skip to content

Commit b919190

Browse files
authored
Merge pull request #15531 from jczaja/prv-googlenet-fix
Performance and functional fixes to LRN
2 parents f825158 + 5885c5c commit b919190

File tree

2 files changed

+24
-8
lines changed

2 files changed

+24
-8
lines changed

paddle/fluid/inference/tests/api/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ else()
5454
message(WARNING "These tests has been disabled in OSX or WITH_MKL=OFF before being fixed: \n test_analyzer_seq_pool1")
5555
endif()
5656

57+
5758
# RNN2
5859
set(RNN2_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/rnn2")
5960
download_model_and_data(${RNN2_INSTALL_DIR} "rnn2_model.tar.gz" "rnn2_data.txt.tar.gz")
@@ -115,6 +116,10 @@ if (NOT EXISTS ${MOBILENET_INSTALL_DIR})
115116
endif()
116117
inference_analysis_api_test_with_refer_result(test_analyzer_mobilenet_transpose ${MOBILENET_INSTALL_DIR} analyzer_vis_tester.cc SERIAL)
117118

119+
# googlenet
120+
inference_analysis_api_test_with_fake_data(test_analyzer_googlenet
121+
"${INFERENCE_DEMO_INSTALL_DIR}/googlenet" analyzer_resnet50_tester.cc "googlenet.tar.gz" SERIAL)
122+
118123
# resnet50
119124
inference_analysis_api_test_with_fake_data(test_analyzer_resnet50
120125
"${INFERENCE_DEMO_INSTALL_DIR}/resnet50" analyzer_resnet50_tester.cc "resnet50_model.tar.gz" SERIAL)

paddle/fluid/operators/lrn_mkldnn_op.cc

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,13 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
6767
mid->mutable_data<T>(ctx.GetPlace());
6868

6969
const int n = ctx.Attr<int>("n");
70-
const float alpha = ctx.Attr<float>("alpha");
70+
// MKL-DNN implements LRN in a caffe way:
71+
// http://caffe.berkeleyvision.org/tutorial/layers/lrn.html
72+
// Where sum of squares is divided by size of normalization window
73+
// this is not the case for PaddlePaddle LRN.
74+
// Hence we need to compensate for this diffrence by
75+
// multipliing alpha by size of window(n)
76+
const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
7177
const float beta = ctx.Attr<float>("beta");
7278
const float k = ctx.Attr<float>("k");
7379
const bool is_test = ctx.Attr<bool>("is_test");
@@ -78,10 +84,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
7884
auto dims = paddle::framework::vectorize2int(x->dims());
7985

8086
auto src_md = paddle::platform::MKLDNNMemDesc(
81-
dims, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
82-
83-
auto dst_md = paddle::platform::MKLDNNMemDesc(
84-
dims, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
87+
dims, mkldnn::memory::data_type::f32, x->format());
8588

8689
auto forward_desc = mkldnn::lrn_forward::desc{mkldnn::prop_kind::forward,
8790
mkldnn::lrn_across_channels,
@@ -92,8 +95,6 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
9295
k};
9396

9497
auto src_memory_pd = mkldnn::memory::primitive_desc{src_md, mkldnn_engine};
95-
auto dst_memory = mkldnn::memory{{dst_md, mkldnn_engine},
96-
static_cast<void*>(output_data)};
9798

9899
if (!is_test) {
99100
const std::string key = ctx.op().Output("Out");
@@ -110,20 +111,30 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
110111
src_memory->set_data_handle(
111112
static_cast<void*>(const_cast<T*>(input_data)));
112113

114+
auto dst_memory = mkldnn::memory(forward_pd->dst_primitive_desc(),
115+
static_cast<void*>(output_data));
113116
auto workspace_memory = insert_to_context<mkldnn::memory>(
114117
key_workspace_memory, dev_ctx,
115118
forward_pd->workspace_primitive_desc());
116119

117120
run_primitive(*forward_pd, *src_memory, *workspace_memory, dst_memory);
121+
122+
out->set_layout(framework::DataLayout::kMKLDNN);
123+
out->set_format(platform::GetMKLDNNFormat(dst_memory));
118124
} else {
119125
auto forward_pd =
120126
mkldnn::lrn_forward::primitive_desc{forward_desc, mkldnn_engine};
121127
auto src_memory = mkldnn::memory{
122128
src_memory_pd, static_cast<void*>(const_cast<T*>(input_data))};
123129
auto workspace_memory =
124130
mkldnn::memory{forward_pd.workspace_primitive_desc()};
131+
auto dst_memory = mkldnn::memory(forward_pd.dst_primitive_desc(),
132+
static_cast<void*>(output_data));
125133

126134
run_primitive(forward_pd, src_memory, workspace_memory, dst_memory);
135+
136+
out->set_layout(framework::DataLayout::kMKLDNN);
137+
out->set_format(platform::GetMKLDNNFormat(dst_memory));
127138
}
128139
}
129140
};
@@ -151,7 +162,7 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
151162
const std::string key_workspace_memory = key + "@lrn_workspace_memory";
152163

153164
const int n = ctx.Attr<int>("n");
154-
const float alpha = ctx.Attr<float>("alpha");
165+
const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
155166
const float beta = ctx.Attr<float>("beta");
156167
const float k = ctx.Attr<float>("k");
157168

0 commit comments

Comments
 (0)