Skip to content

Commit fb93fc1

Browse files
authored
Improve warning about fp8 deprecated format (#6931)
1 parent 09dc298 commit fb93fc1

File tree

6 files changed

Lines changed: 12 additions & 9 deletions

6 files changed

Lines changed: 12 additions & 9 deletions

python/test/unit/language/test_compile_errors.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -384,10 +384,11 @@ def test_fp8_support(fresh_triton_cache, dtype):
384384

385385
@triton.jit
386386
def dtype_kernel(dtype: tl.constexpr):
387-
_ = tl.full((256, ), 0.0, dtype)
387+
a = tl.full((64, 64), 0.0, dtype)
388+
tl.dot(a, a)
388389

389390
if dtype in warning_dtypes:
390-
ctx = pytest.warns(UserWarning, match=r"fp8e4b15 is deprecated in this architecture")
391+
ctx = pytest.warns(UserWarning, match=r"the use of fp8e4b15 is deprecated on Hopper and later architectures")
391392
elif dtype in supported_dtypes:
392393
ctx = contextlib.nullcontext()
393394
else:

python/triton/language/core.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -570,8 +570,6 @@ def to_ir(self, builder: ir.builder) -> ir.type:
570570
if self.name not in builder.options.supported_fp8_dtypes:
571571
raise ValueError(f'type {self} not supported in this architecture. '
572572
f'The supported fp8 dtypes are {builder.options.supported_fp8_dtypes}')
573-
if self.name in builder.options.deprecated_fp8_dtypes:
574-
warn(f"{self.name} is deprecated in this architecture and will be removed in a future triton release")
575573

576574
if self.name == 'void':
577575
return builder.get_void_ty()

python/triton/language/semantic.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1567,6 +1567,10 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optiona
15671567
assert lhs.dtype == rhs.dtype, f"Both operands must be same dtype. Got {lhs.dtype} and {rhs.dtype}"
15681568

15691569
if lhs.dtype.is_fp8e4b15() or rhs.dtype.is_fp8e4b15():
1570+
if "fp8e4b15" in builder.options.deprecated_fp8_dot_operand_dtypes:
1571+
warnings.warn(
1572+
"the use of fp8e4b15 is deprecated on Hopper and later architectures and can cause significant slow down. It will be removed in a future triton release"
1573+
)
15701574
# We upcast because there's no fp8e4b15 type in MLIR
15711575
lhs = cast(lhs, tl.float16, builder)
15721576
rhs = cast(rhs, tl.float16, builder)

python/triton/runtime/interpreter.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -119,7 +119,7 @@ class InterpreterOptions:
119119
sanitize_overflow: bool = True
120120
arch: str = None
121121
supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e5b16", "fp8e4nv", "fp8e4b8", "fp8e4b15")
122-
deprecated_fp8_dtypes: Tuple[str] = ()
122+
deprecated_fp8_dot_operand_dtypes: Tuple[str] = ()
123123
default_dot_input_precision: str = "tf32"
124124
allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee")
125125
max_num_imprecise_acc_default: int = 0

third_party/amd/backend/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,7 @@ class HIPOptions:
3838
sanitize_overflow: bool = True
3939
arch: str = None
4040
supported_fp8_dtypes: Tuple[str] = ("fp8e5", )
41-
deprecated_fp8_dtypes: Tuple[str] = ()
41+
deprecated_fp8_dot_operand_dtypes: Tuple[str] = ()
4242
default_dot_input_precision: str = "ieee"
4343
allowed_dot_input_precisions: Tuple[str] = ("ieee", )
4444
enable_fp_fusion: bool = True

third_party/nvidia/backend/compiler.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -111,7 +111,7 @@ class CUDAOptions:
111111
launch_cooperative_grid: bool = False
112112
launch_pdl: bool = False
113113
supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e4b15")
114-
deprecated_fp8_dtypes: Tuple[str] = ()
114+
deprecated_fp8_dot_operand_dtypes: Tuple[str] = ()
115115
default_dot_input_precision: str = "tf32"
116116
allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee")
117117
max_num_imprecise_acc_default: bool = None
@@ -166,9 +166,9 @@ def parse_options(self, opts) -> Any:
166166
supported_fp8_dtypes.add("fp8e4nv")
167167
args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes))
168168

169-
if "deprecated_fp8_dtypes" not in args:
169+
if "deprecated_fp8_dot_operand_dtypes" not in args:
170170
if capability >= 90:
171-
args["deprecated_fp8_dtypes"] = ("fp8e4b15", )
171+
args["deprecated_fp8_dot_operand_dtypes"] = ("fp8e4b15", )
172172

173173
if "enable_fp_fusion" not in args:
174174
args["enable_fp_fusion"] = knobs.language.default_fp_fusion

0 commit comments

Comments
 (0)