@@ -115,9 +115,16 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
115
115
if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu;
116
116
117
117
// create mkldnn memory from input x tensor
118
- auto src_memory =
119
- memory ({{{src_tz}, memory::data_type::f32 , x->format ()}, mkldnn_engine},
120
- to_void_cast (x_data));
118
+ mkldnn::memory::format input_format = x->format ();
119
+ if (src_tz.size () == 1 ) {
120
+ input_format = mkldnn::memory::format::x;
121
+ } else if (src_tz.size () == 2 ) {
122
+ input_format = mkldnn::memory::format::nc;
123
+ }
124
+
125
+ auto src_memory = memory (
126
+ {{{src_tz}, memory::data_type::f32 , input_format}, mkldnn_engine},
127
+ to_void_cast (x_data));
121
128
122
129
// create primitive descriptor for batch norm forward
123
130
using bn_fwd_types = bn_type_traits<mkldnn::batch_normalization_forward>;
@@ -251,15 +258,28 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
251
258
using bn_bwd_types = bn_type_traits<mkldnn::batch_normalization_backward>;
252
259
253
260
// create mkldnn memory from input diff_y tensor
254
- auto user_diff_dst_memory =
255
- memory ({{{diff_dst_tz}, memory::data_type::f32 , diff_y->format ()},
256
- mkldnn_engine},
257
- to_void_cast (diff_y_data));
261
+
262
+ mkldnn::memory::format dst_format = x->format ();
263
+ if (diff_dst_tz.size () == 1 ) {
264
+ dst_format = mkldnn::memory::format::x;
265
+ } else if (diff_dst_tz.size () == 2 ) {
266
+ dst_format = mkldnn::memory::format::nc;
267
+ }
268
+ auto user_diff_dst_memory = memory (
269
+ {{{diff_dst_tz}, memory::data_type::f32 , dst_format}, mkldnn_engine},
270
+ to_void_cast (diff_y_data));
258
271
259
272
// create mkldnn memory from input x tensor
260
- auto src_memory =
261
- memory ({{{src_tz}, memory::data_type::f32 , x->format ()}, mkldnn_engine},
262
- to_void_cast (x_data));
273
+ mkldnn::memory::format input_format = x->format ();
274
+ if (src_tz.size () == 1 ) {
275
+ input_format = mkldnn::memory::format::x;
276
+ } else if (src_tz.size () == 2 ) {
277
+ input_format = mkldnn::memory::format::nc;
278
+ }
279
+
280
+ auto src_memory = memory (
281
+ {{{src_tz}, memory::data_type::f32 , input_format}, mkldnn_engine},
282
+ to_void_cast (x_data));
263
283
264
284
// for diff_dst, try to use same format as dst in forward pass
265
285
auto diff_dst_pd = batch_norm_fwd_pd.get ()->dst_primitive_desc ();
0 commit comments