Skip to content

Commit f820573

Browse files
committed
MKLDNN elementwise_mul: Add UTs
1 parent d14858e commit f820573

File tree

1 file changed

+118
-1
lines changed

1 file changed

+118
-1
lines changed

python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py

Lines changed: 118 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,37 @@ def test_check_grad_ingore_x(self):
4949
def test_check_grad_ingore_y(self):
5050
pass
5151

52-
class TestElementwiseMulMKLDNNOp_UnsupportedFormat(ElementwiseMulOp):
52+
@unittest.skip("Not implemented yet.")
53+
class TestElementwiseMulMKLDNNOp_BroadcastNCHW8c(ElementwiseMulOp):
54+
def init_input_output(self):
55+
x = np.random.rand(1, 8, 2, 2).astype(self.dtype)
56+
self.x = x.transpose(0, 2, 3, 1).reshape(1, 8, 2, 2)
57+
self.y = np.random.rand(1, 8).astype(self.dtype)
58+
59+
self.out = x * self.y.reshape(1, 8, 1, 1)
60+
self.out = self.out.transpose(0, 2, 3, 1).reshape(1, 8, 2, 2)
61+
62+
def setUp(self):
63+
super(TestElementwiseMulMKLDNNOp_BroadcastNCHW8c, self).setUp()
64+
self.attrs["x_data_format"] = "nchw8c"
65+
self.attrs["y_data_format"] = "nc"
66+
67+
def init_kernel_type(self):
68+
self.use_mkldnn = True
69+
70+
def init_axis(self):
71+
self.axis = 0
72+
73+
def test_check_grad_normal(self):
74+
pass
75+
76+
def test_check_grad_ingore_x(self):
77+
pass
78+
79+
def test_check_grad_ingore_y(self):
80+
pass
81+
82+
class TestElementwiseMulMKLDNNOp_FallbackNCHW(ElementwiseMulOp):
5383
def init_input_output(self):
5484
self.x = np.random.rand(1, 16, 2, 2).astype(self.dtype)
5585
self.y = np.random.rand(1, 16).astype(self.dtype)
@@ -71,5 +101,92 @@ def test_check_grad_ingore_x(self):
71101
def test_check_grad_ingore_y(self):
72102
pass
73103

104+
class TestElementwiseMulMKLDNNOp_FallbackNCHW16C(ElementwiseMulOp):
105+
def init_input_output(self):
106+
x = np.random.rand(1, 16, 2, 2).astype(self.dtype)
107+
self.x = x.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2)
108+
y = np.random.rand(1, 16, 2, 2).astype(self.dtype)
109+
self.y = y.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2)
110+
111+
self.out = self.x * self.y
112+
113+
def setUp(self):
114+
super(TestElementwiseMulMKLDNNOp_FallbackNCHW16C, self).setUp()
115+
self.attrs["x_data_format"] = "nchw16c"
116+
self.attrs["y_data_format"] = "nchw16c"
117+
118+
def init_kernel_type(self):
119+
self.use_mkldnn = True
120+
121+
def init_axis(self):
122+
self.axis = 0
123+
124+
def test_check_grad_normal(self):
125+
pass
126+
127+
def test_check_grad_ingore_x(self):
128+
pass
129+
130+
def test_check_grad_ingore_y(self):
131+
pass
132+
133+
class TestElementwiseMulMKLDNNOp_FallbackNoReorders(ElementwiseMulOp):
134+
def init_input_output(self):
135+
x = np.random.rand(1, 16, 2, 2).astype(self.dtype)
136+
self.x = x.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2)
137+
y = np.random.rand(1, 16, 2, 2).astype(self.dtype)
138+
self.y = y.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2)
139+
140+
self.out = self.x * self.y
141+
142+
def setUp(self):
143+
super(TestElementwiseMulMKLDNNOp_FallbackNoReorders, self).setUp()
144+
self.attrs["x_data_format"] = "nchw16c"
145+
self.attrs["y_data_format"] = "nchw16c"
146+
147+
def init_kernel_type(self):
148+
self.use_mkldnn = True
149+
150+
def init_axis(self):
151+
self.axis = 0
152+
153+
def test_check_grad_normal(self):
154+
pass
155+
156+
def test_check_grad_ingore_x(self):
157+
pass
158+
159+
def test_check_grad_ingore_y(self):
160+
pass
161+
162+
@unittest.skip("Not implemented yet.")
163+
class TestElementwiseMulMKLDNNOp_FallbackWithReorder(ElementwiseMulOp):
164+
def init_input_output(self):
165+
self.x = np.random.rand(1, 16, 2, 2).astype(self.dtype)
166+
y = np.random.rand(1, 16, 2, 2).astype(self.dtype)
167+
self.y = y.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2)
168+
169+
self.out = self.x * y
170+
171+
def setUp(self):
172+
super(TestElementwiseMulMKLDNNOp_FallbackNCHW16C, self).setUp()
173+
self.attrs["x_data_format"] = "nchw"
174+
self.attrs["y_data_format"] = "nchw16c"
175+
176+
def init_kernel_type(self):
177+
self.use_mkldnn = True
178+
179+
def init_axis(self):
180+
self.axis = 0
181+
182+
def test_check_grad_normal(self):
183+
pass
184+
185+
def test_check_grad_ingore_x(self):
186+
pass
187+
188+
def test_check_grad_ingore_y(self):
189+
pass
190+
74191
if __name__ == '__main__':
75192
unittest.main()

0 commit comments

Comments
 (0)