Skip to content

Commit 5a9ae41

Browse files
authored
Merge pull request #12618 from sfraczek/sfraczek/fix-new-mkldnn-conv-tests
fix UT for mkldnn 0.15
2 parents cf799a6 + c2a437f commit 5a9ae41

File tree

4 files changed

+52
-51
lines changed

4 files changed

+52
-51
lines changed

paddle/fluid/operators/conv_mkldnn_op.cc

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -280,12 +280,16 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
280280
* ('any') which lets a primitive (convolution in this case) choose
281281
* the memory format preferred for best performance
282282
*/
283+
std::string data_format = ctx.Attr<std::string>("data_format");
284+
auto chosen_memory_format =
285+
platform::data_format_to_memory_format(data_format);
286+
283287
auto src_md = platform::MKLDNNMemDesc(
284-
src_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
288+
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
285289
auto weights_md = platform::MKLDNNMemDesc(
286-
weights_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
290+
weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
287291
auto dst_md = platform::MKLDNNMemDesc(
288-
dst_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
292+
dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
289293

290294
// create a conv primitive descriptor and save it for usage in backward
291295
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd =
@@ -423,16 +427,20 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
423427
* ('any') which lets a primitive (conv backward in this case) choose
424428
* the memory format preferred for best performance
425429
*/
430+
std::string data_format = ctx.Attr<std::string>("data_format");
431+
auto chosen_memory_format =
432+
platform::data_format_to_memory_format(data_format);
433+
426434
auto src_md = platform::MKLDNNMemDesc(
427-
src_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
435+
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
428436
auto diff_src_md = platform::MKLDNNMemDesc(
429-
src_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
437+
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
430438
auto weights_md = platform::MKLDNNMemDesc(
431-
weights_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
439+
weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
432440
auto diff_weights_md = platform::MKLDNNMemDesc(
433-
weights_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
441+
weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
434442
auto diff_dst_md = platform::MKLDNNMemDesc(
435-
dst_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
443+
dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
436444

437445
// Retrieve conv_pd from device context
438446
auto conv_pd =

paddle/fluid/platform/mkldnn_helper.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ class MKLDNNHandler {
223223
static std::string GetHash(mkldnn::memory::dims& operand_dims, // NOLINT
224224
const std::string& suffix) {
225225
return dims2str(operand_dims) + suffix;
226-
};
226+
}
227227

228228
protected:
229229
static std::string dims2str(const mkldnn::memory::dims& operand_dims) {
@@ -251,5 +251,17 @@ inline mkldnn::memory::format MKLDNNFormatForSize(
251251
return data_format;
252252
}
253253

254+
inline mkldnn::memory::format data_format_to_memory_format(
255+
const std::string& data_format) {
256+
switch (framework::StringToDataLayout(data_format)) {
257+
case framework::DataLayout::kNHWC:
258+
return mkldnn::memory::format::nhwc;
259+
case framework::DataLayout::kNCHW:
260+
return mkldnn::memory::format::nchw;
261+
default:
262+
return mkldnn::memory::format::any;
263+
}
264+
}
265+
254266
} // namespace platform
255267
} // namespace paddle

python/paddle/fluid/tests/unittests/test_conv2d_mkldnn_op.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,19 @@
2020
class TestMKLDNN(TestConv2dOp):
2121
def init_kernel_type(self):
2222
self.use_mkldnn = True
23+
self.data_format = "NCHW"
2324

2425

2526
class TestMKLDNNWithPad(TestWithPad):
2627
def init_kernel_type(self):
2728
self.use_mkldnn = True
29+
self.data_format = "NCHW"
2830

2931

3032
class TestMKLDNNWithStride(TestWithStride):
3133
def init_kernel_type(self):
3234
self.use_mkldnn = True
35+
self.data_format = "NCHW"
3336

3437

3538
if __name__ == '__main__':

python/paddle/fluid/tests/unittests/test_conv2d_op.py

Lines changed: 20 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def setUp(self):
6666
self.op_type = "conv2d"
6767
self.use_cudnn = False
6868
self.use_mkldnn = False
69+
self.data_format = "AnyLayout"
6970
self.dtype = np.float32
7071
self.init_kernel_type()
7172
self.init_group()
@@ -93,67 +94,44 @@ def setUp(self):
9394
'groups': self.groups,
9495
'dilations': self.dilations,
9596
'use_cudnn': self.use_cudnn,
96-
'use_mkldnn': self.use_mkldnn
97+
'use_mkldnn': self.use_mkldnn,
98+
'data_format': self.data_format
9799
}
98100
self.outputs = {'Output': output}
99101

100102
def testcudnn(self):
101103
return core.is_compiled_with_cuda() and self.use_cudnn
102104

103105
def test_check_output(self):
104-
if self.testcudnn():
105-
place = core.CUDAPlace(0)
106-
self.check_output_with_place(place, atol=1e-5)
107-
else:
108-
self.check_output()
106+
place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace()
107+
self.check_output_with_place(place, atol=1e-5)
109108

110109
def test_check_grad(self):
111110
if self.dtype == np.float16:
112111
return
113-
if self.testcudnn():
114-
place = core.CUDAPlace(0)
115-
self.check_grad_with_place(
116-
place,
117-
set(['Input', 'Filter']),
118-
'Output',
119-
max_relative_error=0.02)
120-
else:
121-
self.check_grad(
122-
set(['Input', 'Filter']), 'Output', max_relative_error=0.02)
112+
place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace()
113+
self.check_grad_with_place(
114+
place, set(['Input', 'Filter']), 'Output', max_relative_error=0.02)
123115

124116
def test_check_grad_no_filter(self):
125117
if self.dtype == np.float16:
126118
return
127-
if self.testcudnn():
128-
place = core.CUDAPlace(0)
129-
self.check_grad_with_place(
130-
place, ['Input'],
131-
'Output',
132-
max_relative_error=0.02,
133-
no_grad_set=set(['Filter']))
134-
else:
135-
self.check_grad(
136-
['Input'],
137-
'Output',
138-
max_relative_error=0.02,
139-
no_grad_set=set(['Filter']))
119+
place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace()
120+
self.check_grad_with_place(
121+
place, ['Input'],
122+
'Output',
123+
max_relative_error=0.02,
124+
no_grad_set=set(['Filter']))
140125

141126
def test_check_grad_no_input(self):
142127
if self.dtype == np.float16:
143128
return
144-
if self.testcudnn():
145-
place = core.CUDAPlace(0)
146-
self.check_grad_with_place(
147-
place, ['Filter'],
148-
'Output',
149-
max_relative_error=0.02,
150-
no_grad_set=set(['Input']))
151-
else:
152-
self.check_grad(
153-
['Filter'],
154-
'Output',
155-
max_relative_error=0.02,
156-
no_grad_set=set(['Input']))
129+
place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace()
130+
self.check_grad_with_place(
131+
place, ['Filter'],
132+
'Output',
133+
max_relative_error=0.02,
134+
no_grad_set=set(['Input']))
157135

158136
def init_test_case(self):
159137
self.pad = [0, 0]

0 commit comments

Comments (0)