@@ -695,38 +695,6 @@ def test_igemmlt_row_scale(dim1, dim4, inner):
695695 print (sum (err3 ) / len (err3 ))
696696
697697
@pytest.mark.parametrize("dim1", get_test_dims(2, 1024, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(2, 1024, n=2), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [0], ids=id_formatter("dim3"))
@pytest.mark.parametrize("dims", [2], ids=id_formatter("dims"))
@pytest.mark.parametrize("dtype", [torch.int8], ids=describe_dtype)
@pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA"))
@pytest.mark.parametrize("orderOut", ["col32", "col_turing", "col_ampere"], ids=id_formatter("orderOut"))
@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose"))
@pytest.mark.deprecated
def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
    """Check F.transform against the reference F.nvidia_transform.

    For random int8 inputs, both the transformed tensor and the shape
    recorded in the returned transform state must agree between the two
    implementations, with and without transposition.
    """
    for _ in range(k):
        # Shape is selected by the `dims` parametrization (currently only 2-D).
        if dims == 2:
            shape = (dim1, dim2)
        elif dims == 3:
            shape = (dim1, dim2, dim3)
        A = torch.randint(10, 99, size=shape, device="cuda").to(dtype)

        # Plant a negative value in the last element so sign handling is exercised.
        A.view(-1)[-1] = -1

        # Reference path transforms the (possibly pre-transposed) input;
        # the path under test is asked to transpose internally.
        reference_input = A.t().contiguous() if transpose else A
        out1, S1 = F.nvidia_transform(reference_input, to_order=orderOut)
        out2, S2 = F.transform(A, to_order=orderOut, transpose=transpose)

        # The shape stored in the transform state must match.
        assert S1[0][0] == S2[0][0]
        assert S1[0][1] == S2[0][1]

        torch.testing.assert_close(out1, out2)
729-
730698@pytest .mark .parametrize ("dim1" , [512 , 2048 ], ids = id_formatter ("dim1" ))
731699@pytest .mark .parametrize ("dim2" , [1024 , 4096 ], ids = id_formatter ("dim2" ))
732700def test_coo_double_quant (dim1 , dim2 ):
@@ -1782,6 +1750,38 @@ def test_percentile_clipping(gtype):
17821750 torch .testing .assert_close (gnorm1 , gnorm2 )
17831751
17841752
@pytest.mark.parametrize("dim1", get_test_dims(2, 1024, n=2), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(2, 1024, n=2), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [0], ids=id_formatter("dim3"))
@pytest.mark.parametrize("dims", [2], ids=id_formatter("dims"))
@pytest.mark.parametrize("dtype", [torch.int8], ids=describe_dtype)
@pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA"))
@pytest.mark.parametrize("orderOut", ["col32", "col_turing", "col_ampere"], ids=id_formatter("orderOut"))
@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose"))
@pytest.mark.deprecated
def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
    """Check F.transform against the reference F.nvidia_transform.

    For random int8 inputs, both the transformed tensor and the shape
    recorded in the returned transform state must agree between the two
    implementations, with and without transposition.
    """
    for _ in range(k):
        # Shape is selected by the `dims` parametrization (currently only 2-D).
        if dims == 2:
            shape = (dim1, dim2)
        elif dims == 3:
            shape = (dim1, dim2, dim3)
        A = torch.randint(10, 99, size=shape, device="cuda").to(dtype)

        # Plant a negative value in the last element so sign handling is exercised.
        A.view(-1)[-1] = -1

        # Reference path transforms the (possibly pre-transposed) input;
        # the path under test is asked to transpose internally.
        reference_input = A.t().contiguous() if transpose else A
        out1, S1 = F.nvidia_transform(reference_input, to_order=orderOut)
        out2, S2 = F.transform(A, to_order=orderOut, transpose=transpose)

        # The shape stored in the transform state must match.
        assert S1[0][0] == S2[0][0]
        assert S1[0][1] == S2[0][1]

        torch.testing.assert_close(out1, out2)
17851785@pytest .mark .parametrize ("dim1" , get_test_dims (2 , 256 , n = 2 ), ids = id_formatter ("dim1" ))
17861786@pytest .mark .parametrize ("dim2" , get_test_dims (2 , 256 , n = 2 ), ids = id_formatter ("dim2" ))
17871787@pytest .mark .parametrize ("dim3" , get_test_dims (2 , 256 , n = 2 ), ids = id_formatter ("dim3" ))
0 commit comments