@@ -19,36 +19,21 @@ limitations under the License. */
19
19
#include " paddle/fluid/platform/mkldnn_helper.h"
20
20
21
21
#include " paddle/fluid/operators/math/jit_kernel.h"
22
- #include " xbyak.h"
23
- #include " xbyak_util.h"
22
+ #include " xbyak/xbyak .h"
23
+ #include " xbyak/ xbyak_util.h"
24
24
25
25
namespace paddle {
26
26
namespace operators {
27
27
28
28
using framework::DataLayout;
29
29
using mkldnn::memory;
30
-
31
- static mkldnn::memory::format StringToMKLDNNFormat (std::string& format) {
32
- std::transform (format.begin (), format.end (), format.begin (), ::tolower);
33
-
34
- if (!format.compare (" nchw" )) {
35
- return memory::format::nchw;
36
- } else if (!format.compare (" nchw16c" )) {
37
- return memory::format::nChw16c;
38
- } else if (!format.compare (" nchw8c" )) {
39
- return memory::format::nChw8c;
40
- } else if (!format.compare (" nhwc" )) {
41
- return memory::format::nhwc;
42
- } else {
43
- return memory::format::any;
44
- }
45
- }
30
+ using platform::StringToMKLDNNFormat;
46
31
47
32
static void UpdateDataFormat (const framework::ExecutionContext& ctx,
48
33
framework::Tensor* tensor, const char * attribute) {
49
34
if (ctx.op ().HasAttr (attribute)) {
50
35
auto format_as_string = ctx.Attr <std::string>(attribute);
51
- auto format = StringToMKLDNNFormat (format_as_string);
36
+ auto format = StringToMKLDNNFormat (& format_as_string);
52
37
if (format != memory::format::any) {
53
38
tensor->set_format (format);
54
39
}
@@ -93,8 +78,8 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
93
78
auto y_dims_untrimmed = y->dims ();
94
79
auto x_int_dims = paddle::framework::vectorize2int (x_dims);
95
80
96
- UpdateDataFormat (ctx, ( Tensor*)x , " x_data_format" );
97
- UpdateDataFormat (ctx, ( Tensor*)y , " y_data_format" );
81
+ UpdateDataFormat (ctx, const_cast < Tensor*>(x) , " x_data_format" );
82
+ UpdateDataFormat (ctx, const_cast < Tensor*>(y) , " y_data_format" );
98
83
99
84
Xbyak::util::Cpu cpu;
100
85
const bool is_avx512_enabled = cpu.has (Xbyak::util::Cpu::tAVX512F);
@@ -156,10 +141,10 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
156
141
auto & dev_ctx = ctx.template device_context <MKLDNNDeviceContext>();
157
142
const auto & mkldnn_engine = dev_ctx.GetEngine ();
158
143
if (!(is_x_nchw || is_x_nc))
159
- ReorderInput<T>(( Tensor*)x , ctx.GetPlace (), mkldnn_engine,
144
+ ReorderInput<T>(const_cast < Tensor*>(x) , ctx.GetPlace (), mkldnn_engine,
160
145
x->dims ().size () == 4 );
161
146
if (!(is_y_nchw || is_y_nc))
162
- ReorderInput<T>(( Tensor*)y , ctx.GetPlace (), mkldnn_engine,
147
+ ReorderInput<T>(const_cast < Tensor*>(y) , ctx.GetPlace (), mkldnn_engine,
163
148
y->dims ().size () == 4 );
164
149
}
165
150
0 commit comments