Skip to content

Commit b037896

Browse files
authored
Merge pull request #11666 from mozga-intel/mozga-intel/Batch_norm_support_other_type
The mkldnn batch norm supports other data format
2 parents fff6fa0 + 61c54db commit b037896

File tree

5 files changed

+38
-21
lines changed

5 files changed

+38
-21
lines changed

paddle/fluid/framework/data_layout_transform.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,9 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
147147
"Input tensor type is not supported: ", in.type().name());
148148
memory::data_type out_type = in_type;
149149

150-
auto in_format = MKLDNNFormatForSize(in_tz.size(), in.format());
150+
auto in_format = platform::MKLDNNFormatForSize(in_tz.size(), in.format());
151151
auto out_format =
152-
MKLDNNFormatForSize(in_tz.size(), ToMKLDNNFormat(out_layout));
152+
platform::MKLDNNFormatForSize(in_tz.size(), ToMKLDNNFormat(out_layout));
153153

154154
void* in_data = GetDataFromTensor(in, in_type);
155155

paddle/fluid/framework/data_layout_transform.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,6 @@ inline MKLDNNDataType ToMKLDNNDataType(const std::type_index type) {
6262
return MKLDNNDataType::data_undef;
6363
}
6464

65-
inline MKLDNNFormat MKLDNNFormatForSize(size_t dims_size,
66-
MKLDNNFormat default_format) {
67-
return (dims_size == 1
68-
? mkldnn::memory::format::x
69-
: dims_size == 2 ? mkldnn::memory::format::nc : default_format);
70-
}
7165
#endif
7266

7367
void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,

paddle/fluid/framework/data_transform.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ limitations under the License. */
1818
#include "paddle/fluid/framework/data_layout_transform.h"
1919
#include "paddle/fluid/framework/data_type_transform.h"
2020

21+
#ifdef PADDLE_WITH_MKLDNN
22+
#include "paddle/fluid/platform/mkldnn_helper.h"
23+
#endif
24+
2125
namespace paddle {
2226
namespace framework {
2327

@@ -48,8 +52,8 @@ void TransformData(const OpKernelType &expected_kernel_type,
4852
// Case1 - transform from Non-MKLDNN OPKernel to MKLDNN OPKernel
4953
// Just set layout/format. No real transform occur
5054

51-
auto out_format =
52-
MKLDNNFormatForSize(in.dims().size(), ToMKLDNNFormat(lin));
55+
auto out_format = platform::MKLDNNFormatForSize(in.dims().size(),
56+
ToMKLDNNFormat(lin));
5357

5458
out.ShareDataWith(input_tensor);
5559
out.set_layout(DataLayout::kMKLDNN);

paddle/fluid/operators/batch_norm_mkldnn_op.cc

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,12 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
115115
if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu;
116116

117117
// create mkldnn memory from input x tensor
118-
auto src_memory =
119-
memory({{{src_tz}, memory::data_type::f32, x->format()}, mkldnn_engine},
120-
to_void_cast(x_data));
118+
mkldnn::memory::format input_format =
119+
platform::MKLDNNFormatForSize(src_tz.size(), x->format());
120+
121+
auto src_memory = memory(
122+
{{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine},
123+
to_void_cast(x_data));
121124

122125
// create primitive descriptor for batch norm forward
123126
using bn_fwd_types = bn_type_traits<mkldnn::batch_normalization_forward>;
@@ -251,15 +254,21 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
251254
using bn_bwd_types = bn_type_traits<mkldnn::batch_normalization_backward>;
252255

253256
// create mkldnn memory from input diff_y tensor
254-
auto user_diff_dst_memory =
255-
memory({{{diff_dst_tz}, memory::data_type::f32, diff_y->format()},
256-
mkldnn_engine},
257-
to_void_cast(diff_y_data));
257+
258+
mkldnn::memory::format dst_format =
259+
platform::MKLDNNFormatForSize(src_tz.size(), diff_y->format());
260+
261+
auto user_diff_dst_memory = memory(
262+
{{{diff_dst_tz}, memory::data_type::f32, dst_format}, mkldnn_engine},
263+
to_void_cast(diff_y_data));
258264

259265
// create mkldnn memory from input x tensor
260-
auto src_memory =
261-
memory({{{src_tz}, memory::data_type::f32, x->format()}, mkldnn_engine},
262-
to_void_cast(x_data));
266+
mkldnn::memory::format input_format =
267+
platform::MKLDNNFormatForSize(src_tz.size(), x->format());
268+
269+
auto src_memory = memory(
270+
{{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine},
271+
to_void_cast(x_data));
263272

264273
// for diff_dst, try to use same format as dst in forward pass
265274
auto diff_dst_pd = batch_norm_fwd_pd.get()->dst_primitive_desc();

paddle/fluid/platform/mkldnn_helper.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ class MKLDNNHandler {
228228
return dstr;
229229
};
230230
return dims2str(operand_dims) + suffix;
231-
};
231+
}
232232

233233
protected:
234234
const MKLDNNDeviceContext& dev_ctx_;
@@ -237,5 +237,15 @@ class MKLDNNHandler {
237237
bool is_reusing_;
238238
};
239239

240+
inline mkldnn::memory::format MKLDNNFormatForSize(
241+
size_t dims_size, mkldnn::memory::format data_format) {
242+
if (dims_size == 1) {
243+
return mkldnn::memory::format::x;
244+
} else if (dims_size == 2) {
245+
return mkldnn::memory::format::nc;
246+
}
247+
return data_format;
248+
}
249+
240250
} // namespace platform
241251
} // namespace paddle

0 commit comments

Comments
 (0)