@@ -49,7 +49,37 @@ def test_check_grad_ingore_x(self):
49
49
def test_check_grad_ingore_y (self ):
50
50
pass
51
51
52
- class TestElementwiseMulMKLDNNOp_UnsupportedFormat (ElementwiseMulOp ):
52
+ @unittest .skip ("Not implemented yet." )
53
+ class TestElementwiseMulMKLDNNOp_BroadcastNCHW8c (ElementwiseMulOp ):
54
+ def init_input_output (self ):
55
+ x = np .random .rand (1 , 8 , 2 , 2 ).astype (self .dtype )
56
+ self .x = x .transpose (0 , 2 , 3 , 1 ).reshape (1 , 8 , 2 , 2 )
57
+ self .y = np .random .rand (1 , 8 ).astype (self .dtype )
58
+
59
+ self .out = x * self .y .reshape (1 , 8 , 1 , 1 )
60
+ self .out = self .out .transpose (0 , 2 , 3 , 1 ).reshape (1 , 8 , 2 , 2 )
61
+
62
+ def setUp (self ):
63
+ super (TestElementwiseMulMKLDNNOp_BroadcastNCHW8c , self ).setUp ()
64
+ self .attrs ["x_data_format" ] = "nchw8c"
65
+ self .attrs ["y_data_format" ] = "nc"
66
+
67
+ def init_kernel_type (self ):
68
+ self .use_mkldnn = True
69
+
70
+ def init_axis (self ):
71
+ self .axis = 0
72
+
73
+ def test_check_grad_normal (self ):
74
+ pass
75
+
76
+ def test_check_grad_ingore_x (self ):
77
+ pass
78
+
79
+ def test_check_grad_ingore_y (self ):
80
+ pass
81
+
82
+ class TestElementwiseMulMKLDNNOp_FallbackNCHW (ElementwiseMulOp ):
53
83
def init_input_output (self ):
54
84
self .x = np .random .rand (1 , 16 , 2 , 2 ).astype (self .dtype )
55
85
self .y = np .random .rand (1 , 16 ).astype (self .dtype )
@@ -71,5 +101,92 @@ def test_check_grad_ingore_x(self):
71
101
def test_check_grad_ingore_y (self ):
72
102
pass
73
103
104
+ class TestElementwiseMulMKLDNNOp_FallbackNCHW16C (ElementwiseMulOp ):
105
+ def init_input_output (self ):
106
+ x = np .random .rand (1 , 16 , 2 , 2 ).astype (self .dtype )
107
+ self .x = x .transpose (0 , 2 , 3 , 1 ).reshape (1 , 16 , 2 , 2 )
108
+ y = np .random .rand (1 , 16 , 2 , 2 ).astype (self .dtype )
109
+ self .y = y .transpose (0 , 2 , 3 , 1 ).reshape (1 , 16 , 2 , 2 )
110
+
111
+ self .out = self .x * self .y
112
+
113
+ def setUp (self ):
114
+ super (TestElementwiseMulMKLDNNOp_FallbackNCHW16C , self ).setUp ()
115
+ self .attrs ["x_data_format" ] = "nchw16c"
116
+ self .attrs ["y_data_format" ] = "nchw16c"
117
+
118
+ def init_kernel_type (self ):
119
+ self .use_mkldnn = True
120
+
121
+ def init_axis (self ):
122
+ self .axis = 0
123
+
124
+ def test_check_grad_normal (self ):
125
+ pass
126
+
127
+ def test_check_grad_ingore_x (self ):
128
+ pass
129
+
130
+ def test_check_grad_ingore_y (self ):
131
+ pass
132
+
133
+ class TestElementwiseMulMKLDNNOp_FallbackNoReorders (ElementwiseMulOp ):
134
+ def init_input_output (self ):
135
+ x = np .random .rand (1 , 16 , 2 , 2 ).astype (self .dtype )
136
+ self .x = x .transpose (0 , 2 , 3 , 1 ).reshape (1 , 16 , 2 , 2 )
137
+ y = np .random .rand (1 , 16 , 2 , 2 ).astype (self .dtype )
138
+ self .y = y .transpose (0 , 2 , 3 , 1 ).reshape (1 , 16 , 2 , 2 )
139
+
140
+ self .out = self .x * self .y
141
+
142
+ def setUp (self ):
143
+ super (TestElementwiseMulMKLDNNOp_FallbackNoReorders , self ).setUp ()
144
+ self .attrs ["x_data_format" ] = "nchw16c"
145
+ self .attrs ["y_data_format" ] = "nchw16c"
146
+
147
+ def init_kernel_type (self ):
148
+ self .use_mkldnn = True
149
+
150
+ def init_axis (self ):
151
+ self .axis = 0
152
+
153
+ def test_check_grad_normal (self ):
154
+ pass
155
+
156
+ def test_check_grad_ingore_x (self ):
157
+ pass
158
+
159
+ def test_check_grad_ingore_y (self ):
160
+ pass
161
+
162
+ @unittest .skip ("Not implemented yet." )
163
+ class TestElementwiseMulMKLDNNOp_FallbackWithReorder (ElementwiseMulOp ):
164
+ def init_input_output (self ):
165
+ self .x = np .random .rand (1 , 16 , 2 , 2 ).astype (self .dtype )
166
+ y = np .random .rand (1 , 16 , 2 , 2 ).astype (self .dtype )
167
+ self .y = y .transpose (0 , 2 , 3 , 1 ).reshape (1 , 16 , 2 , 2 )
168
+
169
+ self .out = self .x * y
170
+
171
+ def setUp (self ):
172
+ super (TestElementwiseMulMKLDNNOp_FallbackNCHW16C , self ).setUp ()
173
+ self .attrs ["x_data_format" ] = "nchw"
174
+ self .attrs ["y_data_format" ] = "nchw16c"
175
+
176
+ def init_kernel_type (self ):
177
+ self .use_mkldnn = True
178
+
179
+ def init_axis (self ):
180
+ self .axis = 0
181
+
182
+ def test_check_grad_normal (self ):
183
+ pass
184
+
185
+ def test_check_grad_ingore_x (self ):
186
+ pass
187
+
188
+ def test_check_grad_ingore_y (self ):
189
+ pass
190
+
74
191
if __name__ == '__main__' :
75
192
unittest .main ()
0 commit comments