@@ -1804,29 +1804,29 @@ def fn(a, b):
18041804 def test_mixed_mm (self ):
18051805 def fn (a , b ):
18061806 return torch .mm (a , b .to (a .dtype ))
1807- self .common (
1808- fn ,
1809- (
1810- torch .randn (8 , 8 ),
1811- torch .randint (- 128 , 127 , (8 , 8 ), dtype = torch .int8 ),
1812- ),
1813- check_lowp = True ,
1814- )
1807+ self .common (
1808+ fn ,
1809+ (
1810+ torch .randn (8 , 8 ),
1811+ torch .randint (- 128 , 127 , (8 , 8 ), dtype = torch .int8 ),
1812+ ),
1813+ check_lowp = True ,
1814+ )
18151815
18161816 @config .patch (force_mixed_mm = True )
18171817 def test_mixed_mm2 (self ):
18181818 def fn (a , b , scale , bias ):
18191819 return torch .mm (a , b .to (a .dtype )) * scale + bias
1820- self .common (
1821- fn ,
1822- (
1823- torch .randn (8 , 8 ),
1824- torch .randint (- 128 , 127 , (8 , 8 ), dtype = torch .int8 ),
1825- torch .randn (8 ),
1826- torch .randn (8 ),
1827- ),
1828- check_lowp = True ,
1829- )
1820+ self .common (
1821+ fn ,
1822+ (
1823+ torch .randn (8 , 8 ),
1824+ torch .randint (- 128 , 127 , (8 , 8 ), dtype = torch .int8 ),
1825+ torch .randn (8 ),
1826+ torch .randn (8 ),
1827+ ),
1828+ check_lowp = True ,
1829+ )
18301830
18311831 @config .patch (use_mixed_mm = True )
18321832 def test_uint4x2_mixed_mm (self ):
@@ -1839,14 +1839,14 @@ def fn(a, b):
18391839 .sub (8 ),
18401840 )
18411841
1842- self .common (
1843- fn ,
1844- (
1845- torch .randn (8 , 8 ),
1846- torch .randint (0 , 255 , (4 , 8 ), dtype = torch .uint8 ),
1847- ),
1848- check_lowp = True ,
1849- )
1842+ self .common (
1843+ fn ,
1844+ (
1845+ torch .randn (8 , 8 ),
1846+ torch .randint (0 , 255 , (4 , 8 ), dtype = torch .uint8 ),
1847+ ),
1848+ check_lowp = True ,
1849+ )
18501850
18511851 def test_scalar_input (self ):
18521852 def fn (x , y ):
@@ -6265,7 +6265,7 @@ def test_rsqrt_dynamic_shapes(self):
62656265 def fn (a , b ):
62666266 r = 1 / math .sqrt (a .size (1 ))
62676267 return torch .bmm (a , b ) / r
6268- return ( r ,)
6268+
62696269
62706270 self .common (
62716271 fn ,
0 commit comments