
Commit 23f944c

cthi authored and facebook-github-bot committed
Split out fast gemv test into its own class (#4852)
Summary:
Pull Request resolved: #4852

X-link: facebookresearch/FBGEMM#1878

Fast gemv tests are broken, though it is unclear since when. Split them into their own class so we can avoid running them later in the new OSS CI. The numerics are potentially suspect (some of the atol values are quite high for a small number of values), but this was not debugged; since the op is not used, let's ignore it for now. With this change there is only 1 broken test left in FP8Tests: apparently stochastic rounding is broken for FP8 Rowwise.

Reviewed By: q10

Differential Revision: D82115129

fbshipit-source-id: 03aae5eda14635fa3372665529508e98eabdde61
1 parent 5efeb0f commit 23f944c
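As a side note on how the split can be used, here is a minimal sketch (an assumption about how the new OSS CI might drive this file, not part of this commit): with the broken fast gemv tests isolated in their own FastGemvTests class, a CI job can deselect the whole class by name via pytest's -k expression.

# Hypothetical CI helper (assumes pytest is the test driver); it deselects the
# split-out FastGemvTests class while still running the rest of quantize_test.py.
import subprocess
import sys


def run_quantize_tests_without_fast_gemv() -> int:
    result = subprocess.run(
        [
            sys.executable,
            "-m",
            "pytest",
            "fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py",
            "-k",
            "not FastGemvTests",  # skip the class added in this commit
        ],
        check=False,
    )
    return result.returncode


if __name__ == "__main__":
    sys.exit(run_quantize_tests_without_fast_gemv())

The same isolation also makes a targeted repro easy once someone debugs the suspect numerics, e.g. by selecting only quantize_test.py::FastGemvTests.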

File tree: 1 file changed (+76 / -76 lines changed)

fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py

Lines changed: 76 additions & 76 deletions
@@ -58,10 +58,20 @@ def evaluate_platform_supports_fp8():
 
 def evaluate_platform_supports_mxfp8():
     if torch.cuda.is_available():
+        if torch.version.hip:
+            return False
         return torch.cuda.get_device_capability() >= (10, 0)
     return False
 
 
+def evaluate_cuda_platform_version(major: int):
+    if torch.version.cuda:
+        return torch.cuda.get_device_capability() >= (major, 0)
+    return False
+
+
+SM90_OR_LATER = evaluate_cuda_platform_version(9)
+
 SUPPORTS_FP8 = evaluate_platform_supports_fp8()
 
 SUPPORTS_MXFP8 = evaluate_platform_supports_mxfp8()
@@ -1898,8 +1908,73 @@ def test_quantize_compile(self) -> None:
         torch.compile(torch.ops.fbgemm.bf16_fast_gemv)(X, W_bf16)
 
     @unittest.skipIf(
-        not torch.version.cuda, "Skip on AMD: fast gemv op is not yet supported."
+        torch.version.hip, "Skip on AMD: cuda quantize op is yet supported."
+    )
+    @settings(deadline=None)
+    @given(
+        K=st.sampled_from([0, 128]),
     )
+    def test_quantize_zero_input(self, K) -> None:
+        w = torch.randn(
+            size=(0, K),
+            dtype=torch.bfloat16,
+            device=self.device,
+        )
+        w_scale_ref = torch.empty(
+            size=(0,),
+            dtype=torch.float32,
+            device=self.device,
+        )
+        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
+        torch.testing.assert_close(w.shape, wq.shape)
+        torch.testing.assert_close(w_scale.shape, w_scale_ref.shape)
+
+    @unittest.skipIf(torch.version.hip, "Skip on AMD: fp8 lite op is yet suported.")
+    @settings(deadline=None)
+    @given(
+        M=st.sampled_from([1, 4]),
+        N=st.sampled_from([1024, 6144]),
+        K=st.sampled_from([512, 3584]),
+        CudaGraph=st.sampled_from([True, False]),
+    )
+    def test_fp8_lite_matmul(self, M: int, N: int, K: int, CudaGraph: bool) -> None:
+        x = (
+            torch.randn(
+                size=(M, K),
+                dtype=torch.bfloat16,
+                device=self.device,
+            )
+            * 0.1
+        )
+        w = (
+            torch.randn(
+                size=(N, K),
+                dtype=torch.bfloat16,
+                device=self.device,
+            )
+            * 0.01
+        )
+        xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
+        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
+        if CudaGraph:
+            zq = torch.ops.fbgemm.f8f8bf16_lite(xq, wq, x_scale * w_scale)
+            g = torch.cuda.CUDAGraph()
+            with torch.cuda.graph(g):
+                zq = torch.ops.fbgemm.f8f8bf16_lite(xq, wq, x_scale * w_scale)
+            g.replay()
+        else:
+            zq = torch.ops.fbgemm.f8f8bf16_lite(xq, wq, x_scale * w_scale)
+        zq_ref = (x @ w.T).to(torch.bfloat16)
+        torch.testing.assert_close(zq, zq_ref, atol=9.0e-2, rtol=9.0e-2)
+
+
+@unittest.skipIf(not torch.cuda.is_available(), "Skip when GPU is not available")
+@unittest.skipIf(not SM90_OR_LATER, "Skip when not SM90+")
+class FastGemvTests(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.device = torch.accelerator.current_accelerator()
+
     def run_gemv(
         self, test_cases, gemv_op, atol, rtol, quantize_w=False, quantize_x=False
     ):
@@ -1933,9 +2008,6 @@ def run_gemv(
             z_ref = (x @ w.T).to(torch.bfloat16).to(self.device)
             torch.testing.assert_close(z, z_ref, atol=atol, rtol=rtol)
 
-    @unittest.skipIf(
-        not torch.version.cuda, "Skip on AMD: fast gemv op is not yet supported."
-    )
     def run_gemv_batched(self, test_cases, gemv_op, atol, rtol):
         for B, M, N, K in test_cases:
             x = (
@@ -1964,9 +2036,6 @@ def run_gemv_batched(self, test_cases, gemv_op, atol, rtol):
             z_ref = torch.bmm(x, w.transpose(1, 2)).to(torch.bfloat16).to(self.device)
             torch.testing.assert_close(z, z_ref, atol=atol, rtol=rtol)
 
-    @unittest.skipIf(
-        not torch.version.cuda, "Skip on AMD: fast gemv op is not yet supported."
-    )
     def test_bf16_gemv(self) -> None:
         test_cases = [
             (1, 128, 256),
@@ -1990,9 +2059,6 @@ def test_bf16_gemv(self) -> None:
         ]
         self.run_gemv(test_cases, torch.ops.fbgemm.bf16_fast_gemv, 9.0e-3, 9.0e-3)
 
-    @unittest.skipIf(
-        not torch.version.cuda, "Skip on AMD: fast gemv op is not yet supported."
-    )
     def test_bf16_fp8_gemv(self) -> None:
         test_cases = [
             (1, 1280, 8192),
@@ -2016,9 +2082,6 @@ def test_bf16_fp8_gemv(self) -> None:
             quantize_w=True,
         )
 
-    @unittest.skipIf(
-        not torch.version.cuda, "Skip on AMD: fast gemv op is not yet supported."
-    )
     def test_fp8_fp8_gemv(self) -> None:
         test_cases = [
             (1, 1280, 8192),
@@ -2055,9 +2118,6 @@ def test_fp8_fp8_gemv(self) -> None:
             quantize_x=True,
         )
 
-    @unittest.skipIf(
-        not torch.version.cuda, "Skip on AMD: fast gemv op is not yet supported."
-    )
     def test_fp8_gemv_batched(self) -> None:
         test_cases = [
             (2, 1, 4096, 5120),
@@ -2082,66 +2142,6 @@ def test_fp8_gemv_batched(self) -> None:
             1.0e-1,
         )
 
-    @unittest.skipIf(
-        torch.version.hip, "Skip on AMD: cuda quantize op is yet supported."
-    )
-    @settings(deadline=None)
-    @given(
-        K=st.sampled_from([0, 128]),
-    )
-    def test_quantize_zero_input(self, K) -> None:
-        w = torch.randn(
-            size=(0, K),
-            dtype=torch.bfloat16,
-            device=self.device,
-        )
-        w_scale_ref = torch.empty(
-            size=(0,),
-            dtype=torch.float32,
-            device=self.device,
-        )
-        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
-        torch.testing.assert_close(w.shape, wq.shape)
-        torch.testing.assert_close(w_scale.shape, w_scale_ref.shape)
-
-    @unittest.skipIf(torch.version.hip, "Skip on AMD: fp8 lite op is yet suported.")
-    @settings(deadline=None)
-    @given(
-        M=st.sampled_from([1, 4]),
-        N=st.sampled_from([1024, 6144]),
-        K=st.sampled_from([512, 3584]),
-        CudaGraph=st.sampled_from([True, False]),
-    )
-    def test_fp8_lite_matmul(self, M: int, N: int, K: int, CudaGraph: bool) -> None:
-        x = (
-            torch.randn(
-                size=(M, K),
-                dtype=torch.bfloat16,
-                device=self.device,
-            )
-            * 0.1
-        )
-        w = (
-            torch.randn(
-                size=(N, K),
-                dtype=torch.bfloat16,
-                device=self.device,
-            )
-            * 0.01
-        )
-        xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
-        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
-        if CudaGraph:
-            zq = torch.ops.fbgemm.f8f8bf16_lite(xq, wq, x_scale * w_scale)
-            g = torch.cuda.CUDAGraph()
-            with torch.cuda.graph(g):
-                zq = torch.ops.fbgemm.f8f8bf16_lite(xq, wq, x_scale * w_scale)
-            g.replay()
-        else:
-            zq = torch.ops.fbgemm.f8f8bf16_lite(xq, wq, x_scale * w_scale)
-        zq_ref = (x @ w.T).to(torch.bfloat16)
-        torch.testing.assert_close(zq, zq_ref, atol=9.0e-2, rtol=9.0e-2)
-
 
 @unittest.skipIf(
     not torch.cuda.is_available() or torch.version.hip,
