Skip to content

Commit 20a8ac9

Browse files
authored
[AMD] Fix downcast to float8e5 software emulation (#7361)
Fix downcast to float8e5 software emulation. This PR re-writes the following software conversions: - FP16 -> FP8E5 RTNE - FP32 -> FP8E5 RTNE - BF16 -> FP8E5 RTNE It adds additional steps for conversion to handle subnormal values and saturation mode.
1 parent d0abc51 commit 20a8ac9

File tree

2 files changed

+205
-118
lines changed

2 files changed

+205
-118
lines changed

python/test/unit/language/test_conversions.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -341,9 +341,6 @@ def test_typeconvert_downcast(src_dtype, dst_dtype, rounding, max_repr, device):
341341
pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on AMDGPU CDNA3")
342342

343343
if is_hip():
344-
if dst_dtype == 'float8e5' and rounding == 'rtne' and not is_hip_cdna4():
345-
pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on CDNA4")
346-
347344
if dst_dtype == 'float8e4nv':
348345
if not rounding == 'rtne':
349346
pytest.skip("float8e4nv downcast tests only supported with RTNE rounding on AMDGPU")

third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 205 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -108,40 +108,48 @@ cvtScalePkDowncastToFp8(Location loc, ConversionPatternRewriter &rewriter,
108108

109109
// Fp16 -> OCP Bf8 (RTNE)
110110

111-
// FP8E5M2 is the open-compute standard FP8E5M2 format. NVIDIA GPU supports it
112-
// natively but we don't have hardware native support on CDNA3.
113-
//
114-
// The SW based downcast with RTNE is not fully functional for the denorm
115-
// values. We need rewrite it if we need to emulate this data type on AMDGPU.
116111
static SmallVector<Value>
117112
Fp16_to_Fp8E5M2_RTNE_SW(Location loc, ConversionPatternRewriter &rewriter,
118113
const SmallVector<Value> &v) {
114+
assert(v.size() == 4);
119115
auto b = TritonLLVMOpBuilder(loc, rewriter);
120-
auto fp16x2VecTy = vec_ty(f16_ty, 2);
121-
Value fp16x2Vec0 = b.undef(fp16x2VecTy);
122-
Value fp16x2Vec1 = b.undef(fp16x2VecTy);
123-
fp16x2Vec0 = b.insert_element(fp16x2VecTy, fp16x2Vec0, v[0], b.i32_val(0));
124-
fp16x2Vec0 = b.insert_element(fp16x2VecTy, fp16x2Vec0, v[1], b.i32_val(1));
125-
fp16x2Vec1 = b.insert_element(fp16x2VecTy, fp16x2Vec1, v[2], b.i32_val(0));
126-
fp16x2Vec1 = b.insert_element(fp16x2VecTy, fp16x2Vec1, v[3], b.i32_val(1));
127-
128-
Value a0 = b.bitcast(fp16x2Vec0, i32_ty);
129-
Value a1 = b.bitcast(fp16x2Vec1, i32_ty);
130116

131-
a0 = b.and_(i32_ty, a0, b.i32_val(0xfffefffe));
132-
a1 = b.and_(i32_ty, a1, b.i32_val(0xfffefffe));
133-
134-
a0 = b.add(i32_ty, a0, b.i32_val(0x00800080));
135-
a1 = b.add(i32_ty, a1, b.i32_val(0x00800080));
136-
137-
auto fp8x4VecTy = vec_ty(i8_ty, 4);
138-
a0 = b.bitcast(a0, fp8x4VecTy);
139-
a1 = b.bitcast(a1, fp8x4VecTy);
117+
SmallVector<Value> result(4);
118+
for (size_t i = 0; i < 4; ++i) {
119+
Value fp16 = v[i];
120+
Value i16 = b.bitcast(fp16, i16_ty);
121+
122+
Value s = b.and_(i16_ty, i16, b.i16_val(0x8000));
123+
Value exp =
124+
b.and_(i16_ty, b.lshr(i16_ty, i16, b.i16_val(10)), b.i16_val(0x1F));
125+
Value man = b.and_(i16_ty, i16, b.i16_val(0x03FF));
126+
Value sig = b.and_(i16_ty, i16, b.i16_val(0x7FFF));
127+
128+
// Round 10-bit mantissa to 2-bit nearest, ties to even
129+
Value bias = b.add(
130+
i16_ty,
131+
b.lshr(i16_ty, b.and_(i16_ty, sig, b.i16_val(0x0100)), b.i16_val(8)),
132+
b.i16_val(0x007F));
133+
i16 = b.add(i16_ty, sig, bias);
134+
135+
// Handle overflow using saturation mode, by setting sig to be the max.
136+
// Any number equal or larger than 0x7B80 after rounding (including
137+
// infinite 0x7C00) will cause overflow
138+
i16 = b.select(b.icmp_uge(sig, b.i16_val(0x7B80)), b.i16_val(0x7B00), i16);
139+
140+
// Handle NaN value by keeping it Nan
141+
i16 = b.select(
142+
b.and_(b.icmp_eq(exp, b.i16_val(0x1F)), b.icmp_ne(man, b.i16_val(0x0))),
143+
b.i16_val(0x7E00), i16);
144+
145+
// Add sign bit
146+
i16 = b.or_(i16_ty, s, i16);
147+
148+
// Truncate to 8-bit
149+
result[i] = b.trunc(i8_ty, b.lshr(i16_ty, i16, b.i16_val(8)));
150+
}
140151

141-
return {b.extract_element(i8_ty, a0, b.i32_val(1)),
142-
b.extract_element(i8_ty, a0, b.i32_val(3)),
143-
b.extract_element(i8_ty, a1, b.i32_val(1)),
144-
b.extract_element(i8_ty, a1, b.i32_val(3))};
152+
return result;
145153
}
146154

147155
static SmallVector<Value>
@@ -377,15 +385,98 @@ static SmallVector<Value> Fp32_to_Fp8E4M3FN(Location loc,
377385
v[0], v[1]);
378386
}
379387

380-
// Convert Fp32 to OCP Bf8 on CDNA4
388+
// Fp32 -> OCP Bf8 (RTNE)
389+
381390
static SmallVector<Value>
382-
Fp32_to_Fp8E5M2_RTNE(Location loc, ConversionPatternRewriter &rewriter,
383-
const SmallVector<Value> &v) {
391+
Fp32_to_Fp8E5M2_RTNE_SW(Location loc, ConversionPatternRewriter &rewriter,
392+
const SmallVector<Value> &v) {
393+
assert(v.size() == 4);
394+
auto b = TritonLLVMOpBuilder(loc, rewriter);
395+
396+
SmallVector<Value> result(4);
397+
for (size_t i = 0; i < 4; ++i) {
398+
Value fp32 = v[i];
399+
Value i32 = b.bitcast(fp32, i32_ty);
400+
401+
Value s = b.and_(i32_ty, i32, b.i32_val(0x80000000));
402+
Value exp =
403+
b.and_(i32_ty, b.lshr(i32_ty, i32, b.i32_val(23)), b.i32_val(0xFF));
404+
Value man = b.and_(i32_ty, i32, b.i32_val(0x007FFFFF));
405+
406+
// Convert 8-bit exponent to 5-bit
407+
Value exp5 = b.select(b.icmp_ult(exp, b.i32_val(0x71)), b.i32_val(0),
408+
b.sub(i32_ty, exp, b.i32_val(0x70)));
409+
410+
// Handle subnormal values (exp5 = 0)
411+
// - exp < 0x6e: mantissa = 0x00000000 (0)
412+
// - exp == 0x6e: mantissa = 0x00000000 (0),
413+
// 0x00200000 (1/4)
414+
// - exp == 0x6f: mantissa = 0x00200000 (1/4),
415+
// 0x00400000 (1/2)
416+
// - exp == 0x70: mantissa = 0x00400000 (1/2),
417+
// 0x00600000 (3/4),
418+
// 0x00800000 (1)
419+
man = b.select(b.icmp_ult(exp, b.i32_val(0x6e)), b.i32_val(0), man);
420+
man = b.select(b.icmp_eq(exp, b.i32_val(0x6e)),
421+
b.select(b.icmp_ne(man, b.i32_val(0)), b.i32_val(0x00200000),
422+
b.i32_val(0)),
423+
man);
424+
man = b.select(b.icmp_eq(exp, b.i32_val(0x6f)),
425+
b.select(b.icmp_uge(man, b.i32_val(0x00400000)),
426+
b.i32_val(0x00400000), b.i32_val(0x00200000)),
427+
man);
428+
man = b.select(
429+
b.icmp_eq(exp, b.i32_val(0x70)),
430+
b.select(b.icmp_ugt(man, b.i32_val(0x00200000)),
431+
b.select(b.icmp_uge(man, b.i32_val(0x00600000)),
432+
b.i32_val(0x00800000), b.i32_val(0x00600000)),
433+
b.i32_val(0x00400000)),
434+
man);
435+
436+
// Round 23-bit mantissa to 2-bit nearest, ties to even
437+
Value sig = b.or_(i32_ty, b.shl(i32_ty, exp5, b.i32_val(23)), man);
438+
Value bias =
439+
b.add(i32_ty,
440+
b.lshr(i32_ty, b.and_(i32_ty, sig, b.i32_val(0x00200000)),
441+
b.i32_val(21)),
442+
b.i32_val(0x000FFFFF));
443+
i32 = b.add(i32_ty, sig, bias);
444+
445+
// Handle overflow using saturation mode, by setting sig to be the max.
446+
// Overflow will happe for the following cases:
447+
// - Any number equal or larger than 0x0F700000 after rounding
448+
// - Exponent larged than 0x8E (including infinite 0xFF)
449+
i32 = b.select(b.or_(b.icmp_ugt(exp, b.i32_val(0x8E)),
450+
b.icmp_uge(sig, b.i32_val(0x0F700000))),
451+
b.i32_val(0x0F7FFFFF), i32);
452+
453+
// Handle NaN value by keeping it Nan
454+
i32 = b.select(
455+
b.and_(b.icmp_eq(exp, b.i32_val(0xFF)), b.icmp_ne(man, b.i32_val(0x0))),
456+
b.i32_val(0x0FC00000), i32);
457+
458+
// Add sign bit
459+
i32 = b.or_(i32_ty, b.lshr(i32_ty, s, b.i32_val(3)), i32);
460+
461+
// Truncate to 8-bit
462+
result[i] = b.trunc(i8_ty, b.lshr(i32_ty, i32, b.i32_val(21)));
463+
}
464+
return result;
465+
}
466+
467+
static SmallVector<Value>
468+
Fp32_to_Fp8E5M2_RTNE_HW(Location loc, ConversionPatternRewriter &rewriter,
469+
const SmallVector<Value> &v) {
384470
assert(v.size() == 2);
385471
return cvtScalePkDowncastToFp8<ROCDL::CvtScaleF32PkBf8F32Op>(loc, rewriter,
386472
v[0], v[1]);
387473
}
388474

475+
ConverterT Fp32_to_Fp8E5M2_RTNE(AMD::ISAFamily isaFamily) {
476+
return isaFamily == AMD::ISAFamily::CDNA4 ? Fp32_to_Fp8E5M2_RTNE_HW
477+
: Fp32_to_Fp8E5M2_RTNE_SW;
478+
}
479+
389480
// Fp32 -> Nanoo Bf8 on CDNA3
390481
static SmallVector<Value>
391482
Fp32_to_Fp8E5M2FNUZ(Location loc, ConversionPatternRewriter &rewriter,
@@ -853,86 +944,77 @@ ConverterT Fp8E5M2_to_Bf16(AMD::ISAFamily isaFamily) {
853944
static SmallVector<Value>
854945
Bf16_to_Fp8E5M2_SW(Location loc, ConversionPatternRewriter &rewriter,
855946
const SmallVector<Value> &v) {
947+
assert(v.size() == 4);
856948
auto b = TritonLLVMOpBuilder(loc, rewriter);
857-
auto bf16x2VecTy = vec_ty(bf16_ty, 2);
858-
Value bf16x2Vec0 = b.undef(bf16x2VecTy);
859-
Value bf16x2Vec1 = b.undef(bf16x2VecTy);
860-
bf16x2Vec0 = b.insert_element(bf16x2VecTy, bf16x2Vec0, v[0], b.i32_val(0));
861-
bf16x2Vec0 = b.insert_element(bf16x2VecTy, bf16x2Vec0, v[1], b.i32_val(1));
862-
bf16x2Vec1 = b.insert_element(bf16x2VecTy, bf16x2Vec1, v[2], b.i32_val(0));
863-
bf16x2Vec1 = b.insert_element(bf16x2VecTy, bf16x2Vec1, v[3], b.i32_val(1));
864-
bf16x2Vec0 = b.bitcast(bf16x2Vec0, i32_ty);
865-
bf16x2Vec1 = b.bitcast(bf16x2Vec1, i32_ty);
866-
867-
Value sign0 = b.and_(i32_ty, bf16x2Vec0, b.i32_val(0x80008000));
868-
Value sign1 = b.and_(i32_ty, bf16x2Vec1, b.i32_val(0x80008000));
869-
auto fp8x4VecTy = vec_ty(i8_ty, 4);
870-
Value sign = b.undef(fp8x4VecTy);
871-
sign0 = b.bitcast(sign0, fp8x4VecTy);
872-
sign1 = b.bitcast(sign1, fp8x4VecTy);
873-
sign = b.insert_element(fp8x4VecTy, sign,
874-
b.extract_element(i8_ty, sign0, b.i32_val(1)),
875-
b.i32_val(0));
876-
sign = b.insert_element(fp8x4VecTy, sign,
877-
b.extract_element(i8_ty, sign0, b.i32_val(3)),
878-
b.i32_val(1));
879-
sign = b.insert_element(fp8x4VecTy, sign,
880-
b.extract_element(i8_ty, sign1, b.i32_val(1)),
881-
b.i32_val(2));
882-
sign = b.insert_element(fp8x4VecTy, sign,
883-
b.extract_element(i8_ty, sign1, b.i32_val(3)),
884-
b.i32_val(3));
885-
sign = b.bitcast(sign, i32_ty);
886-
887-
Value nosign0 = b.and_(i32_ty, bf16x2Vec0, b.i32_val(0x7fff7fff));
888-
Value nosign1 = b.and_(i32_ty, bf16x2Vec1, b.i32_val(0x7fff7fff));
889-
890-
Value nosign_0_0 = b.and_(i32_ty, nosign0, b.i32_val(0xffff0000));
891-
nosign_0_0 = b.umax(i32_ty, nosign_0_0, b.i32_val(0x38000000));
892-
nosign_0_0 = b.umin(i32_ty, nosign_0_0, b.i32_val(0x57e00000));
893-
Value nosign_0_1 = b.and_(i32_ty, nosign0, b.i32_val(0x0000ffff));
894-
nosign_0_1 = b.umax(i32_ty, nosign_0_1, b.i32_val(0x3800));
895-
nosign_0_1 = b.umin(i32_ty, nosign_0_1, b.i32_val(0x57e0));
896-
nosign0 = b.or_(i32_ty, nosign_0_0, nosign_0_1);
897-
898-
Value nosign_1_0 = b.and_(i32_ty, nosign1, b.i32_val(0xffff0000));
899-
nosign_1_0 = b.umax(i32_ty, nosign_1_0, b.i32_val(0x38000000));
900-
nosign_1_0 = b.umin(i32_ty, nosign_1_0, b.i32_val(0x57e00000));
901-
Value nosign_1_1 = b.and_(i32_ty, nosign1, b.i32_val(0x0000ffff));
902-
nosign_1_1 = b.umax(i32_ty, nosign_1_1, b.i32_val(0x3800));
903-
nosign_1_1 = b.umin(i32_ty, nosign_1_1, b.i32_val(0x57e0));
904-
nosign1 = b.or_(i32_ty, nosign_1_0, nosign_1_1);
905-
906-
nosign0 = b.add(i32_ty, nosign0, b.i32_val(0x00100010));
907-
nosign1 = b.add(i32_ty, nosign1, b.i32_val(0x00100010));
908-
nosign0 = b.sub(i32_ty, nosign0, b.i32_val(0x38003800));
909-
nosign1 = b.sub(i32_ty, nosign1, b.i32_val(0x38003800));
910-
nosign0 = b.shl(i32_ty, nosign0, b.i32_val(3));
911-
nosign1 = b.shl(i32_ty, nosign1, b.i32_val(3));
912-
913-
nosign0 = b.bitcast(nosign0, fp8x4VecTy);
914-
nosign1 = b.bitcast(nosign1, fp8x4VecTy);
915-
Value nosign = b.undef(fp8x4VecTy);
916-
nosign = b.insert_element(fp8x4VecTy, nosign,
917-
b.extract_element(i8_ty, nosign0, b.i32_val(1)),
918-
b.i32_val(0));
919-
nosign = b.insert_element(fp8x4VecTy, nosign,
920-
b.extract_element(i8_ty, nosign0, b.i32_val(3)),
921-
b.i32_val(1));
922-
nosign = b.insert_element(fp8x4VecTy, nosign,
923-
b.extract_element(i8_ty, nosign1, b.i32_val(1)),
924-
b.i32_val(2));
925-
nosign = b.insert_element(fp8x4VecTy, nosign,
926-
b.extract_element(i8_ty, nosign1, b.i32_val(3)),
927-
b.i32_val(3));
928-
nosign = b.bitcast(nosign, i32_ty);
929-
930-
Value fp8x4Vec = b.or_(i32_ty, nosign, sign);
931-
fp8x4Vec = b.bitcast(fp8x4Vec, fp8x4VecTy);
932-
return {b.extract_element(i8_ty, fp8x4Vec, b.i32_val(0)),
933-
b.extract_element(i8_ty, fp8x4Vec, b.i32_val(1)),
934-
b.extract_element(i8_ty, fp8x4Vec, b.i32_val(2)),
935-
b.extract_element(i8_ty, fp8x4Vec, b.i32_val(3))};
949+
950+
SmallVector<Value> result(4);
951+
for (size_t i = 0; i < 4; ++i) {
952+
Value fp16 = v[i];
953+
Value i16 = b.bitcast(fp16, i16_ty);
954+
955+
Value s = b.and_(i16_ty, i16, b.i16_val(0x8000));
956+
Value exp =
957+
b.and_(i16_ty, b.lshr(i16_ty, i16, b.i16_val(7)), b.i16_val(0xFF));
958+
Value man = b.and_(i16_ty, i16, b.i16_val(0x7F));
959+
960+
// Convert 8-bit exponent to 5-bit exponent
961+
Value exp5 = b.select(b.icmp_ult(exp, b.i16_val(0x71)), b.i16_val(0),
962+
b.sub(i16_ty, exp, b.i16_val(0x70)));
963+
964+
// Handle subnormal values (exp5 = 0)
965+
// - exp < 0x6e: mantissa = 0x0000 (0)
966+
// - exp == 0x6e: mantissa = 0x0000 (0),
967+
// 0x0020 (1/4)
968+
// - exp == 0x6f: mantissa = 0x0020 (1/4),
969+
// 0x0040 (1/2)
970+
// - exp == 0x70: mantissa = 0x0040 (1/2),
971+
// 0x0060 (3/4),
972+
// 0x0080 (1)
973+
man = b.select(b.icmp_ult(exp, b.i16_val(0x6e)), b.i16_val(0), man);
974+
man = b.select(
975+
b.icmp_eq(exp, b.i16_val(0x6e)),
976+
b.select(b.icmp_ne(man, b.i16_val(0)), b.i16_val(0x0020), b.i16_val(0)),
977+
man);
978+
man = b.select(b.icmp_eq(exp, b.i16_val(0x6f)),
979+
b.select(b.icmp_uge(man, b.i16_val(0x0040)),
980+
b.i16_val(0x0040), b.i16_val(0x0020)),
981+
man);
982+
man = b.select(b.icmp_eq(exp, b.i16_val(0x70)),
983+
b.select(b.icmp_ugt(man, b.i16_val(0x0020)),
984+
b.select(b.icmp_uge(man, b.i16_val(0x0060)),
985+
b.i16_val(0x0080), b.i16_val(0x0060)),
986+
b.i16_val(0x0040)),
987+
man);
988+
989+
// Round 7-bit mantissa to 2-bit
990+
Value sig = b.or_(i16_ty, b.shl(i16_ty, exp5, b.i16_val(7)), man);
991+
Value bias = b.add(
992+
i16_ty,
993+
b.lshr(i16_ty, b.and_(i16_ty, sig, b.i16_val(0x0020)), b.i16_val(5)),
994+
b.i16_val(0x000F));
995+
i16 = b.add(i16_ty, sig, bias);
996+
997+
// Handle overflow using saturation mode, by setting sig to be the max.
998+
// Overflow will happe for the following cases:
999+
// - Any number equal or larger than 0x0F70 after rounding
1000+
// - Exponent larged than 0x8E (including infinite 0xFF)
1001+
i16 = b.select(b.or_(b.icmp_ugt(exp, b.i16_val(0x8E)),
1002+
b.icmp_uge(sig, b.i16_val(0x0F70))),
1003+
b.i16_val(0x0F7F), i16);
1004+
1005+
// Handle NaN value by keeping it Nan
1006+
i16 = b.select(
1007+
b.and_(b.icmp_eq(exp, b.i16_val(0xFF)), b.icmp_ne(man, b.i16_val(0x0))),
1008+
b.i16_val(0x0FC0), i16);
1009+
1010+
// Add sign bit
1011+
i16 = b.or_(i16_ty, b.lshr(i16_ty, s, b.i16_val(3)), i16);
1012+
1013+
// Truncate to 8-bit
1014+
result[i] = b.trunc(i8_ty, b.lshr(i16_ty, i16, b.i16_val(5)));
1015+
}
1016+
1017+
return result;
9361018
}
9371019

9381020
static SmallVector<Value>
@@ -1262,7 +1344,8 @@ struct FpToFpOpConversion
12621344
{{F32TyID, F8E5M2FNUZTyID, RoundingMode::RTNE},
12631345
Fp32_to_Fp8E5M2FNUZ},
12641346
{{F32TyID, F8E4M3FNTyID, RoundingMode::RTNE}, Fp32_to_Fp8E4M3FN},
1265-
{{F32TyID, F8E5M2TyID, RoundingMode::RTNE}, Fp32_to_Fp8E5M2_RTNE},
1347+
{{F32TyID, F8E5M2TyID, RoundingMode::RTNE},
1348+
Fp32_to_Fp8E5M2_RTNE(isaFamily)},
12661349
{{F32TyID, F8E5M2TyID, RoundingMode::RTZ}, Fp32_to_Fp8E5M2_RTZ},
12671350
{{F8E4M3FNUZTyID, F32TyID, undefRounding}, Fp8E4M3FNUZ_to_Fp32},
12681351
{{F8E5M2FNUZTyID, F32TyID, undefRounding}, Fp8E5M2FNUZ_to_Fp32},
@@ -1318,16 +1401,23 @@ struct FpToFpOpConversion
13181401
numElements = 4;
13191402
}
13201403

1321-
// f32->fp8/bf8 with rtne, if neither nanoo fp8/bf8 on CDNA3 nor ocp fp8/bf8
1322-
// on CDNA4, is done in two steps: f32->fp16 with rtne and fp16->fp8/bf8
1323-
// with rtne
1404+
// fp32 -> fp8 with rtne can be done in two steps:
1405+
// - fp32 -> fp16 with rtne and
1406+
// - fp16 -> fp8 with rtne
1407+
// with the following exceptions:
1408+
// 1. fp32 -> ocp fp8/bf8 on CDNA4: has hardware support
1409+
// 2. fp32 -> nanoo fp8/bf8 on non-CDNA4: has hardware support
1410+
// 3. fp32 -> ocp bf8 on non-CDNA4: has software support
13241411
bool useFP16IntermediateSrc =
13251412
srcElementType.isF32() && !dstElementType.isF16() &&
13261413
roundingMode == RoundingMode::RTNE &&
13271414
!(isaFamily == AMD::ISAFamily::CDNA4 &&
13281415
(llvm::isa<Float8E4M3FNType, Float8E5M2Type>(dstElementType))) &&
13291416
!(isaFamily == AMD::ISAFamily::CDNA3 &&
1330-
(llvm::isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(dstElementType)));
1417+
(llvm::isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(
1418+
dstElementType))) &&
1419+
!(isaFamily != AMD::ISAFamily::CDNA4 &&
1420+
(llvm::isa<Float8E5M2Type>(dstElementType)));
13311421

13321422
// fp8/bf8->f32, if neither nanoo fp8/bf8 on CDNA3 nor ocp fp8/bf8 on CDNA4,
13331423
// is done in two steps: fp8/bf8->fp16 and fp16->fp32

0 commit comments

Comments
 (0)