Skip to content

Commit b8a04c2

Browse files
committed
Duplicated code was moved to common function
1 parent 3b12833 commit b8a04c2

File tree

2 files changed

+18
-19
lines changed

2 files changed

+18
-19
lines changed

paddle/fluid/operators/batch_norm_mkldnn_op.cc

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -115,12 +115,8 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
115115
if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu;
116116

117117
// create mkldnn memory from input x tensor
118-
mkldnn::memory::format input_format = x->format();
119-
if (src_tz.size() == 1) {
120-
input_format = mkldnn::memory::format::x;
121-
} else if (src_tz.size() == 2) {
122-
input_format = mkldnn::memory::format::nc;
123-
}
118+
mkldnn::memory::format input_format =
119+
platform::MKLDNNFormatForSize(src_tz.size(), x->format());
124120

125121
auto src_memory = memory(
126122
{{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine},
@@ -259,23 +255,16 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
259255

260256
// create mkldnn memory from input diff_y tensor
261257

262-
mkldnn::memory::format dst_format = x->format();
263-
if (diff_dst_tz.size() == 1) {
264-
dst_format = mkldnn::memory::format::x;
265-
} else if (diff_dst_tz.size() == 2) {
266-
dst_format = mkldnn::memory::format::nc;
267-
}
258+
mkldnn::memory::format dst_format =
259+
platform::MKLDNNFormatForSize(src_tz.size(), diff_y->format());
260+
268261
auto user_diff_dst_memory = memory(
269262
{{{diff_dst_tz}, memory::data_type::f32, dst_format}, mkldnn_engine},
270263
to_void_cast(diff_y_data));
271264

272265
// create mkldnn memory from input x tensor
273-
mkldnn::memory::format input_format = x->format();
274-
if (src_tz.size() == 1) {
275-
input_format = mkldnn::memory::format::x;
276-
} else if (src_tz.size() == 2) {
277-
input_format = mkldnn::memory::format::nc;
278-
}
266+
mkldnn::memory::format input_format =
267+
platform::MKLDNNFormatForSize(src_tz.size(), x->format());
279268

280269
auto src_memory = memory(
281270
{{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine},

paddle/fluid/platform/mkldnn_helper.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ class MKLDNNHandler {
228228
return dstr;
229229
};
230230
return dims2str(operand_dims) + suffix;
231-
};
231+
}
232232

233233
protected:
234234
const MKLDNNDeviceContext& dev_ctx_;
@@ -237,5 +237,15 @@ class MKLDNNHandler {
237237
bool is_reusing_;
238238
};
239239

240+
inline mkldnn::memory::format MKLDNNFormatForSize(
241+
size_t dims_size, mkldnn::memory::format data_format) {
242+
if (dims_size == 1) {
243+
return mkldnn::memory::format::x;
244+
} else if (dims_size == 2) {
245+
return mkldnn::memory::format::nc;
246+
}
247+
return data_format;
248+
}
249+
240250
} // namespace platform
241251
} // namespace paddle

0 commit comments

Comments
 (0)