@@ -58,10 +58,20 @@ def evaluate_platform_supports_fp8():

def evaluate_platform_supports_mxfp8():
    if torch.cuda.is_available():
+        if torch.version.hip:
+            return False
        return torch.cuda.get_device_capability() >= (10, 0)
    return False


+def evaluate_cuda_platform_version(major: int):
+    if torch.version.cuda:
+        return torch.cuda.get_device_capability() >= (major, 0)
+    return False
+
+
+SM90_OR_LATER = evaluate_cuda_platform_version(9)
+
SUPPORTS_FP8 = evaluate_platform_supports_fp8()

SUPPORTS_MXFP8 = evaluate_platform_supports_mxfp8()
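The SM90_OR_LATER flag introduced here is consumed later in this diff, where the relocated fast-gemv tests are gated at the class level instead of per test method. A minimal sketch of that gating pattern, assuming it sits in the same module as the flag; ExampleHopperTests and its test body are illustrative placeholders, not part of the change:

    import unittest

    import torch


    @unittest.skipIf(not torch.cuda.is_available(), "Skip when GPU is not available")
    @unittest.skipIf(not SM90_OR_LATER, "Skip when not SM90+")
    class ExampleHopperTests(unittest.TestCase):
        def test_device_is_sm90_or_newer(self) -> None:
            # evaluate_cuda_platform_version(9) returns False on ROCm and
            # CPU-only builds, so this body only runs on SM90+ CUDA devices.
            self.assertGreaterEqual(torch.cuda.get_device_capability(), (9, 0))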
@@ -1898,8 +1908,73 @@ def test_quantize_compile(self) -> None:
        torch.compile(torch.ops.fbgemm.bf16_fast_gemv)(X, W_bf16)

    @unittest.skipIf(
-        not torch.version.cuda, "Skip on AMD: fast gemv op is not yet supported."
+        torch.version.hip, "Skip on AMD: CUDA quantize op is not yet supported."
+    )
+    @settings(deadline=None)
+    @given(
+        K=st.sampled_from([0, 128]),
    )
+    def test_quantize_zero_input(self, K) -> None:
+        w = torch.randn(
+            size=(0, K),
+            dtype=torch.bfloat16,
+            device=self.device,
+        )
+        w_scale_ref = torch.empty(
+            size=(0,),
+            dtype=torch.float32,
+            device=self.device,
+        )
+        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
+        torch.testing.assert_close(w.shape, wq.shape)
+        torch.testing.assert_close(w_scale.shape, w_scale_ref.shape)
+
+    @unittest.skipIf(torch.version.hip, "Skip on AMD: fp8 lite op not yet supported.")
+    @settings(deadline=None)
+    @given(
+        M=st.sampled_from([1, 4]),
+        N=st.sampled_from([1024, 6144]),
+        K=st.sampled_from([512, 3584]),
+        CudaGraph=st.sampled_from([True, False]),
+    )
+    def test_fp8_lite_matmul(self, M: int, N: int, K: int, CudaGraph: bool) -> None:
+        x = (
+            torch.randn(
+                size=(M, K),
+                dtype=torch.bfloat16,
+                device=self.device,
+            )
+            * 0.1
+        )
+        w = (
+            torch.randn(
+                size=(N, K),
+                dtype=torch.bfloat16,
+                device=self.device,
+            )
+            * 0.01
+        )
+        xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
+        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
+        if CudaGraph:
+            zq = torch.ops.fbgemm.f8f8bf16_lite(xq, wq, x_scale * w_scale)
+            g = torch.cuda.CUDAGraph()
+            with torch.cuda.graph(g):
+                zq = torch.ops.fbgemm.f8f8bf16_lite(xq, wq, x_scale * w_scale)
+            g.replay()
+        else:
+            zq = torch.ops.fbgemm.f8f8bf16_lite(xq, wq, x_scale * w_scale)
+        zq_ref = (x @ w.T).to(torch.bfloat16)
+        torch.testing.assert_close(zq, zq_ref, atol=9.0e-2, rtol=9.0e-2)
+
+
+@unittest.skipIf(not torch.cuda.is_available(), "Skip when GPU is not available")
+@unittest.skipIf(not SM90_OR_LATER, "Skip when not SM90+")
+class FastGemvTests(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.device = torch.accelerator.current_accelerator()
+
    def run_gemv(
        self, test_cases, gemv_op, atol, rtol, quantize_w=False, quantize_x=False
    ):
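The test_fp8_lite_matmul case added above runs the op once eagerly as a warm-up, captures the same call in a CUDA graph, and then replays it. A minimal sketch of that warm-up/capture/replay pattern, with a plain matmul standing in for the fbgemm op and run_matmul_under_graph as a placeholder helper:

    import torch

    def run_matmul_under_graph(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # Warm-up call outside capture, mirroring the test above.
        out = a @ b.T
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            # Only the kernels launched here are recorded; g.replay() re-executes
            # them against the same input addresses, which is why the test builds
            # xq, wq, and the scales once before capturing.
            out = a @ b.T
        g.replay()
        return out

The test then compares the replayed result against a bf16 reference (x @ w.T) with loose tolerances that absorb the fp8 quantization error on both operands.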
@@ -1933,9 +2008,6 @@ def run_gemv(
            z_ref = (x @ w.T).to(torch.bfloat16).to(self.device)
            torch.testing.assert_close(z, z_ref, atol=atol, rtol=rtol)

-    @unittest.skipIf(
-        not torch.version.cuda, "Skip on AMD: fast gemv op is not yet supported."
-    )
    def run_gemv_batched(self, test_cases, gemv_op, atol, rtol):
        for B, M, N, K in test_cases:
            x = (
@@ -1964,9 +2036,6 @@ def run_gemv_batched(self, test_cases, gemv_op, atol, rtol):
            z_ref = torch.bmm(x, w.transpose(1, 2)).to(torch.bfloat16).to(self.device)
            torch.testing.assert_close(z, z_ref, atol=atol, rtol=rtol)

-    @unittest.skipIf(
-        not torch.version.cuda, "Skip on AMD: fast gemv op is not yet supported."
-    )
    def test_bf16_gemv(self) -> None:
        test_cases = [
            (1, 128, 256),
@@ -1990,9 +2059,6 @@ def test_bf16_gemv(self) -> None:
        ]
        self.run_gemv(test_cases, torch.ops.fbgemm.bf16_fast_gemv, 9.0e-3, 9.0e-3)

-    @unittest.skipIf(
-        not torch.version.cuda, "Skip on AMD: fast gemv op is not yet supported."
-    )
    def test_bf16_fp8_gemv(self) -> None:
        test_cases = [
            (1, 1280, 8192),
@@ -2016,9 +2082,6 @@ def test_bf16_fp8_gemv(self) -> None:
            quantize_w=True,
        )

-    @unittest.skipIf(
-        not torch.version.cuda, "Skip on AMD: fast gemv op is not yet supported."
-    )
    def test_fp8_fp8_gemv(self) -> None:
        test_cases = [
            (1, 1280, 8192),
@@ -2055,9 +2118,6 @@ def test_fp8_fp8_gemv(self) -> None:
            quantize_x=True,
        )

-    @unittest.skipIf(
-        not torch.version.cuda, "Skip on AMD: fast gemv op is not yet supported."
-    )
    def test_fp8_gemv_batched(self) -> None:
        test_cases = [
            (2, 1, 4096, 5120),
@@ -2082,66 +2142,6 @@ def test_fp8_gemv_batched(self) -> None:
            1.0e-1,
        )

-    @unittest.skipIf(
-        torch.version.hip, "Skip on AMD: cuda quantize op is yet supported."
-    )
-    @settings(deadline=None)
-    @given(
-        K=st.sampled_from([0, 128]),
-    )
-    def test_quantize_zero_input(self, K) -> None:
-        w = torch.randn(
-            size=(0, K),
-            dtype=torch.bfloat16,
-            device=self.device,
-        )
-        w_scale_ref = torch.empty(
-            size=(0,),
-            dtype=torch.float32,
-            device=self.device,
-        )
-        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
-        torch.testing.assert_close(w.shape, wq.shape)
-        torch.testing.assert_close(w_scale.shape, w_scale_ref.shape)
-
-    @unittest.skipIf(torch.version.hip, "Skip on AMD: fp8 lite op is yet suported.")
-    @settings(deadline=None)
-    @given(
-        M=st.sampled_from([1, 4]),
-        N=st.sampled_from([1024, 6144]),
-        K=st.sampled_from([512, 3584]),
-        CudaGraph=st.sampled_from([True, False]),
-    )
-    def test_fp8_lite_matmul(self, M: int, N: int, K: int, CudaGraph: bool) -> None:
-        x = (
-            torch.randn(
-                size=(M, K),
-                dtype=torch.bfloat16,
-                device=self.device,
-            )
-            * 0.1
-        )
-        w = (
-            torch.randn(
-                size=(N, K),
-                dtype=torch.bfloat16,
-                device=self.device,
-            )
-            * 0.01
-        )
-        xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
-        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
-        if CudaGraph:
-            zq = torch.ops.fbgemm.f8f8bf16_lite(xq, wq, x_scale * w_scale)
-            g = torch.cuda.CUDAGraph()
-            with torch.cuda.graph(g):
-                zq = torch.ops.fbgemm.f8f8bf16_lite(xq, wq, x_scale * w_scale)
-            g.replay()
-        else:
-            zq = torch.ops.fbgemm.f8f8bf16_lite(xq, wq, x_scale * w_scale)
-        zq_ref = (x @ w.T).to(torch.bfloat16)
-        torch.testing.assert_close(zq, zq_ref, atol=9.0e-2, rtol=9.0e-2)
-

    @unittest.skipIf(
        not torch.cuda.is_available() or torch.version.hip,