Commit 056ddb0

Use get_fp8_constants from fp8_utils.py instead of fbgemm_gpu (#444)
1 parent 0063982 commit 056ddb0

File tree

4 files changed (+56 -13 lines)


tritonbench/operators/fp8_gemm_rowwise/aoti_fp8_triton_mm.py

Lines changed: 2 additions & 4 deletions

@@ -4,12 +4,10 @@
 import torch
 import triton
 import triton.language as tl
-
-from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
-    get_fp8_constants as get_fp8_constants,
-)
 from triton import Config
 
+from tritonbench.utils.fp8_utils import get_fp8_constants
+
 FP8_DTYPE, _, _, _ = get_fp8_constants()
 E4M3_MAX_POS: float = torch.finfo(FP8_DTYPE).max
 EPS: float = 1e-12
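For context, a minimal sketch (not part of this commit) of the usual rowwise FP8 quantization recipe these constants serve; quantize_fp8_rowwise is a hypothetical helper, not a tritonbench API:

import torch

from tritonbench.utils.fp8_utils import get_fp8_constants

# get_fp8_constants() returns (torch dtype, triton dtype, max value, eps).
FP8_DTYPE, _, E4M3_MAX_POS, EPS = get_fp8_constants()

def quantize_fp8_rowwise(x: torch.Tensor):
    # Per-row absolute max, clamped by EPS so all-zero rows don't divide by zero.
    row_max = x.abs().amax(dim=1, keepdim=True).clamp(min=EPS)
    scale = E4M3_MAX_POS / row_max
    xq = (x * scale).clamp(-E4M3_MAX_POS, E4M3_MAX_POS).to(FP8_DTYPE)
    # Return the quantized rows plus the reciprocal scale for dequantization.
    return xq, scale.reciprocal()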

tritonbench/operators/fp8_gemm_rowwise/operator.py

Lines changed: 2 additions & 1 deletion

@@ -69,9 +69,10 @@ def parse_args(args: List[str]) -> argparse.Namespace:
 HAS_CUTLASS_OR_CK = False
 HAS_CUBLAS = False
 
+from tritonbench.utils.fp8_utils import get_fp8_constants
+
 try:
     from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
-        get_fp8_constants as get_fp8_constants,
         matmul_fp8_row as triton_fp8_row,
     )
 

tritonbench/operators/fp8_gemm_rowwise_grouped/operator.py

Lines changed: 1 addition & 8 deletions

@@ -157,14 +157,7 @@ def parse_args(args: List[str]) -> argparse.Namespace:
 HAS_TRITON = False
 HAS_CUTLASS_OR_CK = False
 
-# Try to import Triton GEMM module
-try:
-    from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
-        get_fp8_constants as get_fp8_constants,
-    )
-except (ImportError, AssertionError):
-    # If import fails, set HAS_TRITON to False
-    HAS_TRITON = False
+from tritonbench.utils.fp8_utils import get_fp8_constants
 
 # Try to import Triton grouped GEMM module
 try:
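The unconditional import is safe here because tritonbench.utils.fp8_utils depends only on torch and triton; the ImportError guard remains only for the fbgemm-backed kernels. For reference, a minimal sketch of that guarded-import pattern (placeholder module name, not a real dependency):

HAS_BACKEND = True
try:
    import some_optional_backend  # placeholder for an optional backend such as fbgemm_gpu
except (ImportError, AssertionError):
    # Backend missing or incompatible: record that it is unavailable and fall back.
    HAS_BACKEND = False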

tritonbench/utils/fp8_utils.py

Lines changed: 51 additions & 0 deletions

@@ -0,0 +1,51 @@
+"""FP8 utilities for tritonbench operators."""
+
+import functools
+import os
+from typing import Tuple
+
+import torch
+import triton.language as tl
+
+
+@functools.lru_cache
+def supports_float8_fnuz(throw_on_hip_incompatibility: bool = True) -> bool:
+    if torch.version.hip:
+        device_capability = torch.cuda.get_device_capability()
+
+        if device_capability < (9, 4):
+            gpu_arch = torch.cuda.get_device_properties("cuda").gcnArchName
+            msg = f"Unsupported GPU arch: {gpu_arch} for FP8"
+            if throw_on_hip_incompatibility:
+                raise RuntimeError(msg)
+            else:
+                import logging
+
+                logging.error(msg)
+                return False
+
+        elif device_capability == (9, 4):
+            return True
+
+    return False
+
+
+def get_fp8_constants() -> Tuple[torch.dtype, tl.dtype, float, float]:
+    """
+    Helper function to get constant values for the current platform.
+
+    Returns:
+        pt_dtype (torch.dtype): The correct torch fp8 datatype.
+        tl_dtype (tl.dtype): The correct triton fp8 datatype.
+        max_fp8 (float): The maximum representable value for the fp8 datatype.
+        eps (float): Minimum clip value to prevent divide by zero.
+    """
+    running_on_github: bool = os.getenv("GITHUB_ENV") is not None
+    if supports_float8_fnuz(throw_on_hip_incompatibility=(not running_on_github)):
+        pt_fp8_dtype = torch.float8_e4m3fnuz
+        tl_fp8_dtype = tl.float8e4b8
+    else:
+        pt_fp8_dtype = torch.float8_e4m3fn
+        tl_fp8_dtype = tl.float8e4nv
+
+    return pt_fp8_dtype, tl_fp8_dtype, torch.finfo(pt_fp8_dtype).max, 1e-12
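For reference, a minimal usage sketch of the new helper, mirroring how the operators above consume it; the printed values are what torch.finfo reports, with the NVIDIA branch assumed:

from tritonbench.utils.fp8_utils import get_fp8_constants

pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
print(pt_dtype)  # torch.float8_e4m3fn on CUDA; torch.float8_e4m3fnuz on MI300-class ROCm
print(max_fp8)   # torch.finfo(pt_dtype).max: 448.0 for float8_e4m3fn, 240.0 for float8_e4m3fnuz
print(eps)       # 1e-12, the divide-by-zero clip floor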
