Skip to content

Commit d74bb6a

Browse files
committed
fix ut for mkldnn 0.15 - added forcing layout NCHW in mkldnn conv tests
1 parent c144634 commit d74bb6a

File tree

4 files changed

+35
-10
lines changed

4 files changed

+35
-10
lines changed

paddle/fluid/operators/conv_mkldnn_op.cc

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -280,12 +280,16 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
280280
* ('any') which lets a primitive (convolution in this case) choose
281281
* the memory format preferred for best performance
282282
*/
283+
std::string data_format = ctx.Attr<std::string>("data_format");
284+
auto chosen_memory_format =
285+
platform::data_format_to_memory_format(data_format);
286+
283287
auto src_md = platform::MKLDNNMemDesc(
284-
src_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
288+
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
285289
auto weights_md = platform::MKLDNNMemDesc(
286-
weights_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
290+
weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
287291
auto dst_md = platform::MKLDNNMemDesc(
288-
dst_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
292+
dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
289293

290294
// create a conv primitive descriptor and save it for usage in backward
291295
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd =
@@ -423,16 +427,20 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
423427
* ('any') which lets a primitive (conv backward in this case) choose
424428
* the memory format preferred for best performance
425429
*/
430+
std::string data_format = ctx.Attr<std::string>("data_format");
431+
auto chosen_memory_format =
432+
platform::data_format_to_memory_format(data_format);
433+
426434
auto src_md = platform::MKLDNNMemDesc(
427-
src_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
435+
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
428436
auto diff_src_md = platform::MKLDNNMemDesc(
429-
src_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
437+
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
430438
auto weights_md = platform::MKLDNNMemDesc(
431-
weights_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
439+
weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
432440
auto diff_weights_md = platform::MKLDNNMemDesc(
433-
weights_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
441+
weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
434442
auto diff_dst_md = platform::MKLDNNMemDesc(
435-
dst_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
443+
dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
436444

437445
// Retrieve conv_pd from device context
438446
auto conv_pd =

paddle/fluid/platform/mkldnn_helper.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ class MKLDNNHandler {
223223
static std::string GetHash(mkldnn::memory::dims& operand_dims, // NOLINT
224224
const std::string& suffix) {
225225
return dims2str(operand_dims) + suffix;
226-
};
226+
}
227227

228228
protected:
229229
static std::string dims2str(const mkldnn::memory::dims& operand_dims) {
@@ -251,5 +251,17 @@ inline mkldnn::memory::format MKLDNNFormatForSize(
251251
return data_format;
252252
}
253253

254+
inline mkldnn::memory::format data_format_to_memory_format(
255+
const std::string& data_format) {
256+
switch (framework::StringToDataLayout(data_format)) {
257+
case framework::DataLayout::kNHWC:
258+
return mkldnn::memory::format::nhwc;
259+
case framework::DataLayout::kNCHW:
260+
return mkldnn::memory::format::nchw;
261+
default:
262+
return mkldnn::memory::format::any;
263+
}
264+
}
265+
254266
} // namespace platform
255267
} // namespace paddle

python/paddle/fluid/tests/unittests/test_conv2d_mkldnn_op.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,19 @@
2020
class TestMKLDNN(TestConv2dOp):
2121
def init_kernel_type(self):
2222
self.use_mkldnn = True
23+
self.data_format = "NCHW"
2324

2425

2526
class TestMKLDNNWithPad(TestWithPad):
2627
def init_kernel_type(self):
2728
self.use_mkldnn = True
29+
self.data_format = "NCHW"
2830

2931

3032
class TestMKLDNNWithStride(TestWithStride):
3133
def init_kernel_type(self):
3234
self.use_mkldnn = True
35+
self.data_format = "NCHW"
3336

3437

3538
if __name__ == '__main__':

python/paddle/fluid/tests/unittests/test_conv2d_op.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def setUp(self):
6666
self.op_type = "conv2d"
6767
self.use_cudnn = False
6868
self.use_mkldnn = False
69+
self.data_format = "AnyLayout"
6970
self.dtype = np.float32
7071
self.init_kernel_type()
7172
self.init_group()
@@ -93,7 +94,8 @@ def setUp(self):
9394
'groups': self.groups,
9495
'dilations': self.dilations,
9596
'use_cudnn': self.use_cudnn,
96-
'use_mkldnn': self.use_mkldnn
97+
'use_mkldnn': self.use_mkldnn,
98+
'data_format': self.data_format
9799
}
98100
self.outputs = {'Output': output}
99101

0 commit comments

Comments
 (0)