@@ -67,7 +67,13 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     mid->mutable_data<T>(ctx.GetPlace());
 
     const int n = ctx.Attr<int>("n");
-    const float alpha = ctx.Attr<float>("alpha");
+    // MKL-DNN implements LRN in a caffe way:
+    // http://caffe.berkeleyvision.org/tutorial/layers/lrn.html
+    // where the sum of squares is divided by the size of the normalization
+    // window; this is not the case for PaddlePaddle LRN.
+    // Hence we need to compensate for this difference by
+    // multiplying alpha by the size of the window (n).
+    const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
     const float beta = ctx.Attr<float>("beta");
     const float k = ctx.Attr<float>("k");
     const bool is_test = ctx.Attr<bool>("is_test");
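The comment block added here carries the key reasoning of the patch: MKL-DNN (like caffe) divides the sum of squares by the window size before scaling by alpha, while PaddlePaddle scales the raw sum. A minimal standalone sketch of the two formulas, confirming that pre-multiplying alpha by n makes them agree (the helper names are illustrative only, not part of this patch):

```cpp
#include <cmath>
#include <cstdio>

// Caffe/MKL-DNN style denominator: the sum of squares is divided by
// the window size n before alpha is applied.
float denom_caffe(float sum_sq, int n, float alpha, float beta, float k) {
  return std::pow(k + alpha / static_cast<float>(n) * sum_sq, beta);
}

// PaddlePaddle style denominator: alpha scales the raw sum of squares.
float denom_paddle(float sum_sq, float alpha, float beta, float k) {
  return std::pow(k + alpha * sum_sq, beta);
}

int main() {
  const int n = 5;
  const float alpha = 1e-4f, beta = 0.75f, k = 2.0f, sum_sq = 3.0f;
  // Passing alpha * n to the caffe-style formula reproduces the
  // PaddlePaddle result; this is exactly the compensation the kernel applies.
  std::printf("%f == %f\n", denom_caffe(sum_sq, n, alpha * n, beta, k),
              denom_paddle(sum_sq, alpha, beta, k));
  return 0;
}
```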
@@ -78,10 +84,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto dims = paddle::framework::vectorize2int(x->dims());
 
     auto src_md = paddle::platform::MKLDNNMemDesc(
-        dims, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
-
-    auto dst_md = paddle::platform::MKLDNNMemDesc(
-        dims, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
+        dims, mkldnn::memory::data_type::f32, x->format());
 
     auto forward_desc = mkldnn::lrn_forward::desc{mkldnn::prop_kind::forward,
                                                   mkldnn::lrn_across_channels,
@@ -92,8 +95,6 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                                                   k};
 
     auto src_memory_pd = mkldnn::memory::primitive_desc{src_md, mkldnn_engine};
-    auto dst_memory = mkldnn::memory{{dst_md, mkldnn_engine},
-                                     static_cast<void*>(output_data)};
 
     if (!is_test) {
       const std::string key = ctx.op().Output("Out");
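Two related descriptor changes land above: src_md now takes its format from the input tensor (x->format()) instead of hardcoding nchw, so an input already in an MKL-DNN blocked layout is consumed as-is, and the hand-built dst_md plus the eagerly constructed dst_memory are dropped so the destination can be created later from the primitive descriptor. As a rough guide to the helper being called, here is an assumed sketch of what paddle::platform::MKLDNNMemDesc wraps under the MKL-DNN 0.x API (the body is a reconstruction, not taken from this patch):

```cpp
#include <vector>
#include <mkldnn.hpp>

// Assumed sketch: build an MKL-DNN memory descriptor from tensor
// dimensions, a data type, and a layout tag. Passing x->format() as
// `fmt` lets the descriptor follow the tensor's actual layout.
inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector<int>& dims,
                                          mkldnn::memory::data_type dt,
                                          mkldnn::memory::format fmt) {
  return mkldnn::memory::desc(dims, dt, fmt);
}
```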
@@ -110,20 +111,30 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       src_memory->set_data_handle(
           static_cast<void*>(const_cast<T*>(input_data)));
 
+      auto dst_memory = mkldnn::memory(forward_pd->dst_primitive_desc(),
+                                       static_cast<void*>(output_data));
       auto workspace_memory = insert_to_context<mkldnn::memory>(
           key_workspace_memory, dev_ctx,
           forward_pd->workspace_primitive_desc());
 
       run_primitive(*forward_pd, *src_memory, *workspace_memory, dst_memory);
+
+      out->set_layout(framework::DataLayout::kMKLDNN);
+      out->set_format(platform::GetMKLDNNFormat(dst_memory));
     } else {
       auto forward_pd =
           mkldnn::lrn_forward::primitive_desc{forward_desc, mkldnn_engine};
       auto src_memory = mkldnn::memory{
           src_memory_pd, static_cast<void*>(const_cast<T*>(input_data))};
       auto workspace_memory =
           mkldnn::memory{forward_pd.workspace_primitive_desc()};
+      auto dst_memory = mkldnn::memory(forward_pd.dst_primitive_desc(),
+                                       static_cast<void*>(output_data));
 
       run_primitive(forward_pd, src_memory, workspace_memory, dst_memory);
+
+      out->set_layout(framework::DataLayout::kMKLDNN);
+      out->set_format(platform::GetMKLDNNFormat(dst_memory));
     }
   }
 };
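With dst_md gone, both branches now build dst_memory from the primitive descriptor's dst_primitive_desc(), letting MKL-DNN choose its preferred (possibly blocked) output layout; the set_layout/set_format calls then record that choice on the output tensor so downstream ops can interpret the buffer or reorder it back to nchw. A hedged sketch of how such a format tag can be queried from an mkldnn::memory under the 0.x API, in the spirit of platform::GetMKLDNNFormat (the helper name here is illustrative):

```cpp
#include <mkldnn.hpp>

// Assumed sketch: extract the layout tag (e.g. nchw, nChw8c) from a
// memory object's descriptor using the MKL-DNN 0.x C++ API.
inline mkldnn::memory::format QueryMKLDNNFormat(const mkldnn::memory& mem) {
  // desc().data is the underlying C struct mkldnn_memory_desc_t;
  // its `format` field holds the memory layout tag.
  return static_cast<mkldnn::memory::format>(
      mem.get_primitive_desc().desc().data.format);
}
```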
@@ -151,7 +162,7 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     const std::string key_workspace_memory = key + "@lrn_workspace_memory";
 
     const int n = ctx.Attr<int>("n");
-    const float alpha = ctx.Attr<float>("alpha");
+    const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
     const float beta = ctx.Attr<float>("beta");
     const float k = ctx.Attr<float>("k");
 