@@ -115,12 +115,8 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
115
115
if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu;
116
116
117
117
// create mkldnn memory from input x tensor
118
- mkldnn::memory::format input_format = x->format ();
119
- if (src_tz.size () == 1 ) {
120
- input_format = mkldnn::memory::format::x;
121
- } else if (src_tz.size () == 2 ) {
122
- input_format = mkldnn::memory::format::nc;
123
- }
118
+ mkldnn::memory::format input_format =
119
+ platform::MKLDNNFormatForSize (src_tz.size (), x->format ());
124
120
125
121
auto src_memory = memory (
126
122
{{{src_tz}, memory::data_type::f32 , input_format}, mkldnn_engine},
@@ -259,23 +255,16 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
259
255
260
256
// create mkldnn memory from input diff_y tensor
261
257
262
- mkldnn::memory::format dst_format = x->format ();
263
- if (diff_dst_tz.size () == 1 ) {
264
- dst_format = mkldnn::memory::format::x;
265
- } else if (diff_dst_tz.size () == 2 ) {
266
- dst_format = mkldnn::memory::format::nc;
267
- }
258
+ mkldnn::memory::format dst_format =
259
+ platform::MKLDNNFormatForSize (src_tz.size (), diff_y->format ());
260
+
268
261
auto user_diff_dst_memory = memory (
269
262
{{{diff_dst_tz}, memory::data_type::f32 , dst_format}, mkldnn_engine},
270
263
to_void_cast (diff_y_data));
271
264
272
265
// create mkldnn memory from input x tensor
273
- mkldnn::memory::format input_format = x->format ();
274
- if (src_tz.size () == 1 ) {
275
- input_format = mkldnn::memory::format::x;
276
- } else if (src_tz.size () == 2 ) {
277
- input_format = mkldnn::memory::format::nc;
278
- }
266
+ mkldnn::memory::format input_format =
267
+ platform::MKLDNNFormatForSize (src_tz.size (), x->format ());
279
268
280
269
auto src_memory = memory (
281
270
{{{src_tz}, memory::data_type::f32 , input_format}, mkldnn_engine},
0 commit comments