Skip to content

Commit 3b12833

Browse files
committed
The mkldnn batch norm supports other data format
1 parent ae0d0c4 commit 3b12833

File tree

1 file changed

+30
-10
lines changed

1 file changed

+30
-10
lines changed

paddle/fluid/operators/batch_norm_mkldnn_op.cc

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,16 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
115115
if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu;
116116

117117
// create mkldnn memory from input x tensor
118-
auto src_memory =
119-
memory({{{src_tz}, memory::data_type::f32, x->format()}, mkldnn_engine},
120-
to_void_cast(x_data));
118+
mkldnn::memory::format input_format = x->format();
119+
if (src_tz.size() == 1) {
120+
input_format = mkldnn::memory::format::x;
121+
} else if (src_tz.size() == 2) {
122+
input_format = mkldnn::memory::format::nc;
123+
}
124+
125+
auto src_memory = memory(
126+
{{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine},
127+
to_void_cast(x_data));
121128

122129
// create primitive descriptor for batch norm forward
123130
using bn_fwd_types = bn_type_traits<mkldnn::batch_normalization_forward>;
@@ -251,15 +258,28 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
251258
using bn_bwd_types = bn_type_traits<mkldnn::batch_normalization_backward>;
252259

253260
// create mkldnn memory from input diff_y tensor
254-
auto user_diff_dst_memory =
255-
memory({{{diff_dst_tz}, memory::data_type::f32, diff_y->format()},
256-
mkldnn_engine},
257-
to_void_cast(diff_y_data));
261+
262+
mkldnn::memory::format dst_format = x->format();
263+
if (diff_dst_tz.size() == 1) {
264+
dst_format = mkldnn::memory::format::x;
265+
} else if (diff_dst_tz.size() == 2) {
266+
dst_format = mkldnn::memory::format::nc;
267+
}
268+
auto user_diff_dst_memory = memory(
269+
{{{diff_dst_tz}, memory::data_type::f32, dst_format}, mkldnn_engine},
270+
to_void_cast(diff_y_data));
258271

259272
// create mkldnn memory from input x tensor
260-
auto src_memory =
261-
memory({{{src_tz}, memory::data_type::f32, x->format()}, mkldnn_engine},
262-
to_void_cast(x_data));
273+
mkldnn::memory::format input_format = x->format();
274+
if (src_tz.size() == 1) {
275+
input_format = mkldnn::memory::format::x;
276+
} else if (src_tz.size() == 2) {
277+
input_format = mkldnn::memory::format::nc;
278+
}
279+
280+
auto src_memory = memory(
281+
{{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine},
282+
to_void_cast(x_data));
263283

264284
// for diff_dst, try to use same format as dst in forward pass
265285
auto diff_dst_pd = batch_norm_fwd_pd.get()->dst_primitive_desc();

0 commit comments

Comments
 (0)