Skip to content

Commit 732c0db

Browse files
authored
[AMD] Support Bf16->OCP Fp8 conversion on CDNA3 (#7469)
1 parent 4207ca4 commit 732c0db

File tree

2 files changed

+36
-18
lines changed

2 files changed

+36
-18
lines changed

python/test/unit/language/test_conversions.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import triton
88
import triton.language as tl
99

10-
from triton._internal_testing import is_cuda, is_hip, is_hip_cdna2, is_hip_cdna3, is_hip_cdna4
10+
from triton._internal_testing import is_cuda, is_hip, is_hip_cdna3, is_hip_cdna4
1111

1212

1313
def matching_int(dtype):
@@ -341,12 +341,6 @@ def test_typeconvert_downcast(src_dtype, dst_dtype, rounding, max_repr, device):
341341
pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on AMDGPU CDNA3")
342342

343343
if is_hip():
344-
if dst_dtype == 'float8e4nv':
345-
if not rounding == 'rtne':
346-
pytest.skip("float8e4nv downcast tests only supported with RTNE rounding on AMDGPU")
347-
if not is_hip_cdna4() and src_dtype == 'bfloat16':
348-
pytest.skip("float8e4nv downcast tests from bfloat16 only supported on AMDGPU CDNA4")
349-
350344
if dst_dtype in ('float8e5b16', 'float8e4b8') and rounding == 'rtne' and not is_hip_cdna3():
351345
pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on AMDGPU CDNA3")
352346

@@ -376,9 +370,6 @@ def test_typeconvert_downcast_clamping(src_dtype, dst_dtype, mode, device, round
376370

377371
if dst_dtype in ('float8e5', 'float8e4nv') and rounding == 'rtne' and torch.cuda.get_device_capability(0) < (9, 0):
378372
pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on NVGPU with compute capability 9.0+")
379-
elif is_hip_cdna2() or is_hip_cdna3():
380-
if src_dtype == 'bfloat16' and dst_dtype == 'float8e4nv':
381-
pytest.skip(f"{src_dtype} downcast to {dst_dtype} with clamping is not fully tested on AMDGPU CDNA2/3")
382373

383374
converter = {
384375
tl.float8e4nv: torch.float8_e4m3fn,

third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,12 @@ template <typename FPType> struct FPTypeInfo {
6565
if (dstTyID == TypeID::get<Float8E5M2Type>())
6666
return {0x0080, 0x0180, 0x0200, 0x0380};
6767
}
68+
if constexpr (std::is_same_v<FPType, BFloat16Type>) {
69+
if (dstTyID == TypeID::get<Float8E4M3FNType>())
70+
return {0x3a80, 0x3b40, 0x3ba0, 0x3be0, 0x3c10, 0x3c30, 0x3c50, 0x3c70};
71+
if (dstTyID == TypeID::get<Float8E5M2Type>())
72+
return {0x3700, 0x37c0, 0x3820, 0x3860};
73+
}
6874
return {};
6975
}
7076

@@ -278,13 +284,16 @@ static Value Fp_to_Fp8E4M3FN_RTNE_oneValue(Location loc,
278284
ConversionPatternRewriter &rewriter,
279285
Value v) {
280286
static_assert((std::is_same_v<SrcFPType, Float32Type>) ||
281-
(std::is_same_v<SrcFPType, Float16Type>));
287+
(std::is_same_v<SrcFPType, Float16Type>) ||
288+
(std::is_same_v<SrcFPType, BFloat16Type>));
282289
auto b = TritonLLVMOpBuilder(loc, rewriter);
283290
const llvm::fltSemantics *srcSemantic = nullptr;
284291
if constexpr (std::is_same_v<SrcFPType, Float32Type>)
285292
srcSemantic = &llvm::APFloat::IEEEsingle();
286-
else
293+
else if constexpr (std::is_same_v<SrcFPType, Float16Type>)
287294
srcSemantic = &llvm::APFloat::IEEEhalf();
295+
else
296+
srcSemantic = &llvm::APFloat::BFloat();
288297
auto srcWidth = llvm::APFloat::getSizeInBits(*srcSemantic);
289298
auto srcMantissaBits = llvm::APFloat::semanticsPrecision(*srcSemantic) - 1;
290299
auto srcExponentBits = srcWidth - srcMantissaBits - 1;
@@ -365,8 +374,10 @@ static Value Fp_to_Fp8E4M3FN_RTNE_oneValue(Location loc,
365374
// number(including infinity) after rounding in FP8E4M3
366375
if constexpr (std::is_same_v<SrcFPType, Float32Type>)
367376
dstMaxOfSrcType |= 0x7ffff;
368-
else
377+
else if constexpr (std::is_same_v<SrcFPType, Float16Type>)
369378
dstMaxOfSrcType |= 0x7f;
379+
else
380+
dstMaxOfSrcType |= 0x7;
370381
Value isOverflowOrInf =
371382
b.icmp_ugt(intVal, srcFpInfo.toLLVMIntValue(dstMaxOfSrcType));
372383
vFp8 =
@@ -1168,15 +1179,30 @@ static ConverterT Bf16_to_Fp8E5M2(AMD::ISAFamily isaFamily) {
11681179
return isaFamily == AMD::ISAFamily::CDNA4 ? Bf16_to_Fp8E5M2_HW
11691180
: Bf16_to_Fp8E5M2_SW;
11701181
}
1171-
// Bf16 -> OCP Fp8
1172-
static SmallVector<Value> Bf16_to_Fp8E4M3FN(Location loc,
1173-
ConversionPatternRewriter &rewriter,
1174-
const SmallVector<Value> &v) {
1182+
1183+
// Bf16 -> OCP Fp8 using RTNE
1184+
// Software Bf16 -> Fp8E4M3FN conversion with RTNE rounding; this is the
// variant Bf16_to_Fp8E4M3FN returns for ISA families other than CDNA4.
// Each of the two inputs in `v` is converted independently through
// Fp_to_Fp8E4M3FN_RTNE_oneValue<BFloat16Type>; results keep the input order.
static SmallVector<Value>
Bf16_to_Fp8E4M3FN_RTNE_SW(Location loc, ConversionPatternRewriter &rewriter,
                          const SmallVector<Value> &v) {
  // v[0] and v[1] are read unconditionally below, so check the operand count
  // up front, matching the check in Bf16_to_Fp8E4M3FN_RTNE_HW.
  assert(v.size() == 2);
  SmallVector<Value> result(2);
  result[0] = Fp_to_Fp8E4M3FN_RTNE_oneValue<BFloat16Type>(loc, rewriter, v[0]);
  result[1] = Fp_to_Fp8E4M3FN_RTNE_oneValue<BFloat16Type>(loc, rewriter, v[1]);
  return result;
}
1192+
1193+
// Bf16 -> Fp8E4M3FN variant that Bf16_to_Fp8E4M3FN returns for CDNA4.
// Forwards the two inputs in `v` to a single
// cvtScalePkDowncastToFp8<ROCDL::CvtScaleF32PkFp8Bf16Op> call and returns
// whatever that helper produces.
static SmallVector<Value>
Bf16_to_Fp8E4M3FN_RTNE_HW(Location loc, ConversionPatternRewriter &rewriter,
                          const SmallVector<Value> &v) {
  // Exactly two operands are expected; both are consumed below.
  assert(v.size() == 2);
  Value first = v[0];
  Value second = v[1];
  return cvtScalePkDowncastToFp8<ROCDL::CvtScaleF32PkFp8Bf16Op>(
      loc, rewriter, first, second);
}
11791200

1201+
// Picks the Bf16 -> Fp8E4M3FN (RTNE) converter for the given ISA family:
// the _HW variant on CDNA4, the _SW variant on every other family.
// Declared `static` to match the sibling selector Bf16_to_Fp8E5M2 and keep
// this file-local helper out of the external symbol table.
static ConverterT Bf16_to_Fp8E4M3FN(AMD::ISAFamily isaFamily) {
  return isaFamily == AMD::ISAFamily::CDNA4 ? Bf16_to_Fp8E4M3FN_RTNE_HW
                                            : Bf16_to_Fp8E4M3FN_RTNE_SW;
}
1205+
11801206
// fp8e4m3fn to bf16
11811207
static SmallVector<Value>
11821208
Fp8E4M3FN_to_Bf16_SW(Location loc, ConversionPatternRewriter &rewriter,
@@ -1472,7 +1498,8 @@ struct FpToFpOpConversion
14721498
// BF16 -> F8
14731499
{{BF16TyID, F8E5M2TyID, RoundingMode::RTNE},
14741500
Bf16_to_Fp8E5M2(isaFamily)},
1475-
{{BF16TyID, F8E4M3FNTyID, RoundingMode::RTNE}, Bf16_to_Fp8E4M3FN},
1501+
{{BF16TyID, F8E4M3FNTyID, RoundingMode::RTNE},
1502+
Bf16_to_Fp8E4M3FN(isaFamily)},
14761503
{{BF16TyID, F8E5M2FNUZTyID, RoundingMode::RTNE},
14771504
Bf16_to_Fp8E5M2FNUZ},
14781505
{{BF16TyID, F8E4M3FNUZTyID, RoundingMode::RTNE},

0 commit comments

Comments
 (0)