@@ -15,6 +15,9 @@ limitations under the License. */
15
15
#include " paddle/fluid/operators/data_norm_op.h"
16
16
#include < string>
17
17
#include " paddle/fluid/framework/data_layout.h"
18
+ #ifdef PADDLE_WITH_MKLDNN
19
+ #include " paddle/fluid/platform/mkldnn_helper.h"
20
+ #endif
18
21
19
22
namespace paddle {
20
23
namespace operators {
@@ -94,6 +97,13 @@ class DataNormOp : public framework::OperatorWithKernel {
94
97
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
95
98
framework::LibraryType library = framework::LibraryType::kPlain ;
96
99
framework::DataLayout layout = framework::DataLayout::kAnyLayout ;
100
+ #ifdef PADDLE_WITH_MKLDNN
101
+ if (library == framework::LibraryType::kPlain &&
102
+ platform::CanMKLDNNBeUsed (ctx)) {
103
+ library = framework::LibraryType::kMKLDNN ;
104
+ layout = framework::DataLayout::kMKLDNN ;
105
+ }
106
+ #endif
97
107
98
108
return framework::OpKernelType (input_data_type, ctx.GetPlace (), layout,
99
109
library);
@@ -251,6 +261,14 @@ class DataNormGradOp : public framework::OperatorWithKernel {
251
261
framework::LibraryType library = framework::LibraryType::kPlain ;
252
262
framework::DataLayout layout = framework::DataLayout::kAnyLayout ;
253
263
264
+ #ifdef PADDLE_WITH_MKLDNN
265
+ if (library == framework::LibraryType::kPlain &&
266
+ platform::CanMKLDNNBeUsed (ctx)) {
267
+ library = framework::LibraryType::kMKLDNN ;
268
+ layout = framework::DataLayout::kMKLDNN ;
269
+ }
270
+ #endif
271
+
254
272
return framework::OpKernelType (ctx.Input <Tensor>(" X" )->type (),
255
273
ctx.GetPlace (), layout, library);
256
274
}
0 commit comments