@@ -22,8 +22,8 @@ class TestMulOp(OpTest):
22
22
def setUp (self ):
23
23
self .op_type = "mul"
24
24
self .inputs = {
25
- 'X' : np .random .random ((32 , 84 )).astype ("float32" ),
26
- 'Y' : np .random .random ((84 , 100 )).astype ("float32" )
25
+ 'X' : np .random .random ((2 , 5 )).astype ("float32" ),
26
+ 'Y' : np .random .random ((5 , 3 )).astype ("float32" )
27
27
}
28
28
self .outputs = {'Out' : np .dot (self .inputs ['X' ], self .inputs ['Y' ])}
29
29
@@ -46,13 +46,16 @@ class TestMulOp2(OpTest):
46
46
def setUp (self ):
47
47
self .op_type = "mul"
48
48
self .inputs = {
49
- 'X' : np .random .random ((15 , 4 , 12 , 10 )).astype ("float32" ),
50
- 'Y' : np .random .random ((4 , 30 , 8 , 2 , 9 )).astype ("float32" )
49
+ 'X' : np .random .random ((3 , 4 , 4 , 3 )).astype ("float32" ),
50
+ 'Y' : np .random .random ((2 , 6 , 1 , 2 , 3 )).astype ("float32" )
51
51
}
52
- self .attrs = {'x_num_col_dims' : 2 , 'y_num_col_dims' : 2 }
53
- result = np .dot (self .inputs ['X' ].reshape (15 * 4 , 12 * 10 ),
54
- self .inputs ['Y' ].reshape (4 * 30 , 8 * 2 * 9 ))
55
- result = result .reshape (15 , 4 , 8 , 2 , 9 )
52
+ self .attrs = {
53
+ 'x_num_col_dims' : 2 ,
54
+ 'y_num_col_dims' : 2 ,
55
+ }
56
+ result = np .dot (self .inputs ['X' ].reshape (3 * 4 , 4 * 3 ),
57
+ self .inputs ['Y' ].reshape (2 * 6 , 1 * 2 * 3 ))
58
+ result = result .reshape (3 , 4 , 1 , 2 , 3 )
56
59
self .outputs = {'Out' : result }
57
60
58
61
def test_check_output (self ):
@@ -73,9 +76,9 @@ def test_check_grad_ignore_y(self):
73
76
class TestFP16MulOp1 (OpTest ):
74
77
def setUp (self ):
75
78
self .op_type = "mul"
76
- x = np .random .random ((32 , 84 )).astype ("float16" )
77
- y = np .random .random ((84 , 100 )).astype ("float16" )
78
- self .inputs = {'X' : x .view (np .uint16 ), 'Y' : y .view (np .uint16 )}
79
+ x = np .random .random ((3 , 5 )).astype ("float16" )
80
+ y = np .random .random ((5 , 4 )).astype ("float16" )
81
+ self .inputs = {'X' : x .view (np .float16 ), 'Y' : y .view (np .float16 )}
79
82
self .outputs = {'Out' : np .dot (x , y )}
80
83
81
84
def test_check_output (self ):
@@ -88,13 +91,15 @@ def test_check_output(self):
88
91
class TestFP16MulOp2 (OpTest ):
89
92
def setUp (self ):
90
93
self .op_type = "mul"
91
- x = np .random .random ((15 , 4 , 12 , 10 )).astype ("float16" )
92
- y = np .random .random ((4 , 30 , 8 , 2 , 9 )).astype ("float16" )
93
- self .inputs = {'X' : x .view (np .uint16 ), 'Y' : y .view (np .uint16 )}
94
- self .attrs = {'x_num_col_dims' : 2 , 'y_num_col_dims' : 2 }
95
- result = np .dot (
96
- x .reshape (15 * 4 , 12 * 10 ), y .reshape (4 * 30 , 8 * 2 * 9 ))
97
- result = result .reshape (15 , 4 , 8 , 2 , 9 )
94
+ x = np .random .random ((3 , 4 , 4 , 3 )).astype ("float16" )
95
+ y = np .random .random ((2 , 6 , 1 , 2 , 3 )).astype ("float16" )
96
+ self .inputs = {'X' : x .view (np .float16 ), 'Y' : y .view (np .float16 )}
97
+ self .attrs = {
98
+ 'x_num_col_dims' : 2 ,
99
+ 'y_num_col_dims' : 2 ,
100
+ }
101
+ result = np .dot (x .reshape (3 * 4 , 4 * 3 ), y .reshape (2 * 6 , 1 * 2 * 3 ))
102
+ result = result .reshape (3 , 4 , 1 , 2 , 3 )
98
103
self .outputs = {'Out' : result }
99
104
100
105
def test_check_output (self ):
0 commit comments