Commit 17364f3

[release/2.7] fp8 inductor tests: Add gfx120x support (#2229)
On gfx120x, Triton supports only float8_e5m2. Create f8_type_pair: for gfx942, use the fnuz types; for gfx1200, use only float8_e5m2; for all other archs, use the default (OCP) fp8 types. Signed-off-by: Jagadish Krishnamoorthy <[email protected]>
1 parent: 5ebff96
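For reference, a minimal standalone restatement of the dtype selection added below (assumes a ROCm build with a visible GPU, as the test module itself does when this runs at import time), useful for checking which pair a given device ends up with:

import torch

# Default: OCP fp8 types, used on CUDA and on ROCm archs not special-cased below.
f8_type_pair = (torch.float8_e4m3fn, torch.float8_e5m2)
if torch.version.hip:  # ROCm build
    arch = torch.cuda.get_device_properties(0).gcnArchName
    if "gfx94" in arch:
        # gfx942: use the fnuz variants.
        f8_type_pair = (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz)
    elif "gfx120" in arch:
        # gfx1200/gfx1201: Triton supports only float8_e5m2.
        f8_type_pair = (torch.float8_e5m2,)
print(f8_type_pair)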

File tree: 1 file changed, +19 −6 lines

test/inductor/test_fp8.py

Lines changed: 19 additions & 6 deletions
@@ -12,6 +12,7 @@
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
+    skipIfRocmArch,
     TEST_WITH_ROCM,
 )
 from torch.testing._internal.inductor_utils import HAS_CUDA
@@ -32,6 +33,17 @@
 FP16_MAX_POS: float = torch.finfo(torch.float16).max
 EPS: float = 1e-12
 
+# fp8 data types for inductor based fp8 tests. This can be different
+# than the one used in eager mode.
+f8_type_pair = (torch.float8_e4m3fn, torch.float8_e5m2)
+if torch.version.hip:
+    arch = torch.cuda.get_device_properties(0).gcnArchName
+    if "gfx94" in arch:
+        # for gfx942, use fnuz data type.
+        f8_type_pair = (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz)
+    elif "gfx120" in arch:
+        # for gfx1200 and gfx1201, e4m3 is not supported on triton.
+        f8_type_pair = (torch.float8_e5m2,)
 
 def _to_fp8_saturated(x: Tensor, float8_dtype: torch.dtype) -> Tensor:
     # The default behavior in PyTorch for casting to `float8_e4m3fn`
@@ -180,10 +192,11 @@ def fp8_matmul_unwrapped(x):
         x = torch.rand(*x_shape, device="cuda", dtype=dtype).to(e4m3_type)
         y_fp8 = compiled_fp8_matmul(x)  # noqa: F841
 
+    @skipIfRocmArch(("gfx1200","gfx1201"))
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
     @parametrize("dtype", (torch.float16, torch.bfloat16, torch.float))
     @parametrize("shape", ("15,3,13", "4,2048,4096"))
-    @parametrize("dst_types", [(torch.float8_e4m3fn, torch.float8_e5m2)])
+    @parametrize("dst_types", [f8_type_pair])
     def test_valid_cast(self, dtype: torch.dtype, shape: str, dst_types: tuple):
         dst_types = _fix_fp8_dtype_for_rocm(dst_types, device="cuda")
         e4m3, e5m2 = dst_types
@@ -227,7 +240,7 @@ def fp8_cast(x, dtype):
 
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
     @parametrize("src_dtype", (torch.float16, torch.bfloat16, torch.float))
-    @parametrize("dst_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
+    @parametrize("dst_dtype", f8_type_pair)
     @parametrize("shape", ("16,16,16", "4,2048,4096"))
     def test_to_fp8_saturated(
         self, src_dtype: torch.dtype, dst_dtype: torch.dtype, shape: str
@@ -249,7 +262,7 @@ def fp8_saturated(x, dtype):
 
     @unittest.skipIf(TEST_WITH_ROCM, "ROCm fails with accuracy issue")
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
-    @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
+    @parametrize("float8_dtype", f8_type_pair)
     @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
     def test_amax_fp8_quant(self, float8_dtype: torch.dtype, shape: str):
         float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device="cuda")
@@ -274,7 +287,7 @@ def amax_fp8(x: Tensor, scale: Tensor):
         torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-2, atol=1e-2)
 
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
-    @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
+    @parametrize("float8_dtype", f8_type_pair)
     @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
     def test_amax_along_with_fp8_quant(self, float8_dtype: torch.dtype, shape: str):
         float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device="cuda")
@@ -305,7 +318,7 @@ def amax_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor):
 
     @unittest.skipIf(TEST_WITH_ROCM, "ROCm fails with accuracy issue")
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
-    @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
+    @parametrize("float8_dtype", f8_type_pair)
     @parametrize("amax_keep_dim", (True, False))
     @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
     def test_layernorm_fp8_quant(
@@ -347,7 +360,7 @@ def ln_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor):
         )
 
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
-    @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
+    @parametrize("float8_dtype", f8_type_pair)
     @parametrize("shape", ("4,2048,4096",))
     @parametrize("keepdim", (False, True))
     def test_layernorm_fp8_quant_benchmark(

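One note on why test_valid_cast alone gains the skipIfRocmArch guard: its body unpacks exactly two dtypes (e4m3, e5m2 = dst_types), so the single-element pair used on gfx120x cannot simply shrink the parametrization the way it does for the other tests. A minimal sketch of the failure the skip avoids (hypothetical stand-in values, not code from the commit):

# Hypothetical illustration: the gfx120x pair has one element, but
# test_valid_cast unpacks two, so without the skip it would raise.
f8_type_pair = ("float8_e5m2",)  # stand-in for (torch.float8_e5m2,)
try:
    e4m3, e5m2 = f8_type_pair
except ValueError as err:
    print(f"would fail without @skipIfRocmArch: {err}")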