Skip to content

Commit ddacd46

Browse files
knwng and ravil-mobile authored
[AMD] Clamp Results in Downcasting to FP8E4M3 and FP8E5M2 (#7337)
There are several conversion ops on the NV side using `satfinite` mode, but on the AMD side, some of those are in non-saturation mode. We need to align AMD ops with NV. For example, fp32 to OCP fp8 on mi350 is lowered to `ROCDL::CvtScaleF32PkFp8F32Op`, and is eventually lowered to `v_cvt_scalef32_pk_fp8_f32`, which, according to ISA, is in non-saturation mode. But on the NV side, it's lowered to `cvt.rn.satfinite.e4m3x2.f32`, which is in saturation mode. Other examples include: | Conversion | ROCDL dialect | Instruction | | ----------------- | ----------------------------- | -------------------------- | | fp32 to fp8e4m3fn | ROCDL::CvtScaleF32PkFp8F32Op | v_cvt_scalef32_pk_fp8_f32 | | fp32 to fp8e5m2 | ROCDL::CvtScaleF32PkBf8F32Op | v_cvt_scalef32_pk_bf8_f32 | | fp16 to fp8e4m3fn | ROCDL::CvtScaleF32PkFp8F16Op | v_cvt_scalef32_pk_fp8_f16 | | fp16 to fp8e5m2 | ROCDL::CvtScaleF32PkBf8F16Op | v_cvt_scalef32_pk_bf8_f16 | | bf16 to fp8e4m3fn | ROCDL::CvtScaleF32PkFp8Bf16Op | v_cvt_scalef32_pk_fp8_bf16 | | bf16 to fp8e5m2 | ROCDL::CvtScaleF32PkBf8Bf16Op | v_cvt_scalef32_pk_bf8_bf16 | This PR fixes this issue by enabling the `FP16_OVFL` flag in the Mode register before these conversion instructions. --------- Co-authored-by: ravil-mobile <[email protected]>
1 parent 677a30c commit ddacd46

File tree

3 files changed

+86
-3
lines changed

3 files changed

+86
-3
lines changed

python/test/unit/language/test_conversions.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import triton
88
import triton.language as tl
99

10-
from triton._internal_testing import is_cuda, is_hip, is_hip_cdna3, is_hip_cdna4
10+
from triton._internal_testing import is_cuda, is_hip, is_hip_cdna2, is_hip_cdna3, is_hip_cdna4
1111

1212

1313
def matching_int(dtype):
@@ -366,3 +366,74 @@ def test_typeconvert_downcast(src_dtype, dst_dtype, rounding, max_repr, device):
366366

367367
for i in range(256):
368368
downcast_test(getattr(tl, src_dtype), getattr(tl, dst_dtype), rounding, *stuff, max_repr, i, device=device)
369+
370+
@pytest.mark.parametrize("mode", [
371+
'max', 'min', 'inf', '-inf', 'nan',
372+
])
373+
@pytest.mark.parametrize("dst_dtype", ["float8e4nv", "float8e5"])
374+
@pytest.mark.parametrize("src_dtype", ["float32", "float16", "bfloat16"])
375+
def test_typeconvert_downcast_clamping(src_dtype, dst_dtype, mode, rounding="rtne", device="cuda"):
376+
if is_cuda():
377+
if src_dtype != 'float32' and torch.cuda.get_device_capability(0) < (9, 0):
378+
pytest.skip("non-float32 downcast tests only supported on NVGPU with compute capability 9.0+")
379+
380+
if dst_dtype in ('float8e5', 'float8e4nv') and rounding == 'rtne' and torch.cuda.get_device_capability(0) < (9, 0):
381+
pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on NVGPU with compute capability 9.0+")
382+
elif is_hip():
383+
if is_hip_cdna2():
384+
pytest.skip(f"{dst_dtype} downcast to {dst_dtype} with clamping is not fully tested on AMDGPU CDNA2")
385+
386+
if is_hip_cdna3():
387+
if src_dtype == 'bfloat16' and dst_dtype == 'float8e4nv':
388+
pytest.skip(f"{src_dtype} downcast to {dst_dtype} with clamping is not fully tested on AMDGPU CDNA3")
389+
if dst_dtype == 'float8e5' and mode in ('inf', '-inf'):
390+
pytest.skip(f"Downcast to {dst_dtype} with clamping for `inf` or `-inf` "
391+
"is not fully tested on AMDGPU CDNA3")
392+
393+
converter = {
394+
tl.float8e4nv: torch.float8_e4m3fn,
395+
tl.float8e5: torch.float8_e5m2,
396+
tl.float16: torch.float16,
397+
tl.bfloat16: torch.bfloat16,
398+
tl.float32: torch.float32
399+
}
400+
401+
tl_src_dtype = getattr(tl, src_dtype)
402+
tl_dst_dtype = getattr(tl, dst_dtype)
403+
404+
torch_src_dtype = converter[tl_src_dtype]
405+
torch_dst_dtype = converter[tl_dst_dtype]
406+
407+
if mode in ('max', 'min'):
408+
# Added to input to exceed the representation range to produce NaN
409+
exceed_value = 100.0
410+
test_value = torch.finfo(torch_dst_dtype).max + exceed_value
411+
expected_result = torch.finfo(torch_dst_dtype).max
412+
elif mode in ('inf', '-inf'):
413+
test_value = torch.inf
414+
expected_result = torch.finfo(torch_dst_dtype).max
415+
else:
416+
assert mode == 'nan'
417+
test_value = torch.nan
418+
expected_result = torch.nan
419+
420+
if mode in ('min', '-inf'):
421+
test_value *= -1.0
422+
expected_result *= -1.0
423+
424+
BLOCK_SIZE = 1024
425+
shape = (BLOCK_SIZE * 2,)
426+
src = torch.full(shape, test_value, dtype=torch_src_dtype, device=device)
427+
dst = torch.empty(shape, dtype=torch_dst_dtype, device=device)
428+
429+
type_convert_triton[(src.shape[0] // BLOCK_SIZE,)](
430+
triton.reinterpret(src, torch_src_dtype),
431+
triton.reinterpret(dst, torch_dst_dtype),
432+
rounding,
433+
BLOCK_SIZE
434+
)
435+
436+
if mode == 'nan':
437+
assert(torch.all(torch.isnan(dst)))
438+
else:
439+
torch.testing.assert_close(dst, torch.full_like(dst, expected_result))

python/triton_kernels/tests/test_mxfp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,8 @@ def test_mxfp_casting(
146146
if is_hip():
147147
if swizzle_value is not None or swizzle_scale is not None:
148148
pytest.skip("Other swizzling patterns are not supported by AMD GPU")
149-
if quant_dtype == 'float8_e4m3fn':
150-
pytest.skip("float8_e4m3fn cast hasn't been fully tested on AMD GPU")
149+
if quant_dtype == 'float8_e4m3fn' and is_hip_cdna3():
150+
pytest.skip("float8_e4m3fn cast hasn't been fully tested on AMD CDNA3")
151151
if quant_dtype == 'float8_e5m2' and is_hip_cdna3():
152152
pytest.skip("float8_e5m2 cast hasn't been fully tested on AMD CDNA3")
153153

third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,17 @@ static SmallVector<Value>
6565
cvtScalePkDowncastToFp8(Location loc, ConversionPatternRewriter &rewriter,
6666
Value v0, Value v1) {
6767
auto b = TritonLLVMOpBuilder(loc, rewriter);
68+
69+
// This is the location of the fp16_ovfl flag in the Mode register. It's
70+
// calculated following this formula:
71+
// (mode register ID = 1) | (Offset << 6) | ((Width - 1) << 11)
72+
// In this case, Offset = 23 and Width = 1.
73+
// When the bit is 0/1, the conversion from fp32/fp16/bf16 to fp8/bf8 is in
74+
// non-saturation/saturation mode.
75+
Value fp16OVFLModeRegLoc = b.i32_val(1473);
76+
LLVM::createLLVMIntrinsicCallOp(rewriter, loc, "llvm.amdgcn.s.setreg", {},
77+
{fp16OVFLModeRegLoc, b.i32_val(1)});
78+
6879
Type v2I16Ty = vec_ty(i16_ty, 2);
6980
Value v2I16Vec = b.undef(v2I16Ty);
7081
Value scale = b.f32_val(1);
@@ -84,6 +95,7 @@ cvtScalePkDowncastToFp8(Location loc, ConversionPatternRewriter &rewriter,
8495
result = rewriter.create<ConvertOp>(loc, v2I16Ty, v2I16Vec, srcVec, scale,
8596
/*dstLoHiSel=*/false);
8697
}
98+
8799
auto fp8x4VecTy = vec_ty(i8_ty, 4);
88100
auto fp8x4Vec = b.bitcast(result, fp8x4VecTy);
89101
SmallVector<Value> ret(2);

0 commit comments

Comments
 (0)