Skip to content

Commit 4ec9de0

Browse files
authored
Merge pull request #14628 from Sand3r-/mgallus/mkldnn-elementwise_mul
EltwiseMul: Changes from previous PR
2 parents fb26822 + 9455be0 commit 4ec9de0

File tree

2 files changed

+25
-23
lines changed

2 files changed

+25
-23
lines changed

paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,36 +19,21 @@ limitations under the License. */
1919
#include "paddle/fluid/platform/mkldnn_helper.h"
2020

2121
#include "paddle/fluid/operators/math/jit_kernel.h"
22-
#include "xbyak.h"
23-
#include "xbyak_util.h"
22+
#include "xbyak/xbyak.h"
23+
#include "xbyak/xbyak_util.h"
2424

2525
namespace paddle {
2626
namespace operators {
2727

2828
using framework::DataLayout;
2929
using mkldnn::memory;
30-
31-
static mkldnn::memory::format StringToMKLDNNFormat(std::string& format) {
32-
std::transform(format.begin(), format.end(), format.begin(), ::tolower);
33-
34-
if (!format.compare("nchw")) {
35-
return memory::format::nchw;
36-
} else if (!format.compare("nchw16c")) {
37-
return memory::format::nChw16c;
38-
} else if (!format.compare("nchw8c")) {
39-
return memory::format::nChw8c;
40-
} else if (!format.compare("nhwc")) {
41-
return memory::format::nhwc;
42-
} else {
43-
return memory::format::any;
44-
}
45-
}
30+
using platform::StringToMKLDNNFormat;
4631

4732
static void UpdateDataFormat(const framework::ExecutionContext& ctx,
4833
framework::Tensor* tensor, const char* attribute) {
4934
if (ctx.op().HasAttr(attribute)) {
5035
auto format_as_string = ctx.Attr<std::string>(attribute);
51-
auto format = StringToMKLDNNFormat(format_as_string);
36+
auto format = StringToMKLDNNFormat(&format_as_string);
5237
if (format != memory::format::any) {
5338
tensor->set_format(format);
5439
}
@@ -93,8 +78,8 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
9378
auto y_dims_untrimmed = y->dims();
9479
auto x_int_dims = paddle::framework::vectorize2int(x_dims);
9580

96-
UpdateDataFormat(ctx, (Tensor*)x, "x_data_format");
97-
UpdateDataFormat(ctx, (Tensor*)y, "y_data_format");
81+
UpdateDataFormat(ctx, const_cast<Tensor*>(x), "x_data_format");
82+
UpdateDataFormat(ctx, const_cast<Tensor*>(y), "y_data_format");
9883

9984
Xbyak::util::Cpu cpu;
10085
const bool is_avx512_enabled = cpu.has(Xbyak::util::Cpu::tAVX512F);
@@ -156,10 +141,10 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
156141
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
157142
const auto& mkldnn_engine = dev_ctx.GetEngine();
158143
if (!(is_x_nchw || is_x_nc))
159-
ReorderInput<T>((Tensor*)x, ctx.GetPlace(), mkldnn_engine,
144+
ReorderInput<T>(const_cast<Tensor*>(x), ctx.GetPlace(), mkldnn_engine,
160145
x->dims().size() == 4);
161146
if (!(is_y_nchw || is_y_nc))
162-
ReorderInput<T>((Tensor*)y, ctx.GetPlace(), mkldnn_engine,
147+
ReorderInput<T>(const_cast<Tensor*>(y), ctx.GetPlace(), mkldnn_engine,
163148
y->dims().size() == 4);
164149
}
165150

paddle/fluid/platform/mkldnn_helper.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ limitations under the License. */
1414
#pragma once
1515

1616
#include <mkldnn.h>
17+
#include <algorithm>
1718
#include <string>
1819
#include <vector>
1920
#include "paddle/fluid/framework/operator.h"
@@ -292,5 +293,21 @@ inline mkldnn::memory::format data_format_to_memory_format(
292293
}
293294
}
294295

296+
inline mkldnn::memory::format StringToMKLDNNFormat(std::string* format) {
297+
std::transform(format->begin(), format->end(), format->begin(), ::tolower);
298+
299+
if (!format->compare("nchw")) {
300+
return mkldnn::memory::format::nchw;
301+
} else if (!format->compare("nchw16c")) {
302+
return mkldnn::memory::format::nChw16c;
303+
} else if (!format->compare("nchw8c")) {
304+
return mkldnn::memory::format::nChw8c;
305+
} else if (!format->compare("nhwc")) {
306+
return mkldnn::memory::format::nhwc;
307+
} else {
308+
return mkldnn::memory::format::any;
309+
}
310+
}
311+
295312
} // namespace platform
296313
} // namespace paddle

0 commit comments

Comments
 (0)