@@ -108,40 +108,48 @@ cvtScalePkDowncastToFp8(Location loc, ConversionPatternRewriter &rewriter,
108108
109109// Fp16 -> OCP Bf8 (RTNE)
110110
111- // FP8E5M2 is the open-compute standard FP8E5M2 format. NVIDIA GPU supports it
112- // natively but we don't have hardware native support on CDNA3.
113- //
114- // The SW based downcast with RTNE is not fully functional for the denorm
115- // values. We need rewrite it if we need to emulate this data type on AMDGPU.
116111static SmallVector<Value>
117112Fp16_to_Fp8E5M2_RTNE_SW (Location loc, ConversionPatternRewriter &rewriter,
118113 const SmallVector<Value> &v) {
114+ assert (v.size () == 4 );
119115 auto b = TritonLLVMOpBuilder (loc, rewriter);
120- auto fp16x2VecTy = vec_ty (f16_ty, 2 );
121- Value fp16x2Vec0 = b.undef (fp16x2VecTy);
122- Value fp16x2Vec1 = b.undef (fp16x2VecTy);
123- fp16x2Vec0 = b.insert_element (fp16x2VecTy, fp16x2Vec0, v[0 ], b.i32_val (0 ));
124- fp16x2Vec0 = b.insert_element (fp16x2VecTy, fp16x2Vec0, v[1 ], b.i32_val (1 ));
125- fp16x2Vec1 = b.insert_element (fp16x2VecTy, fp16x2Vec1, v[2 ], b.i32_val (0 ));
126- fp16x2Vec1 = b.insert_element (fp16x2VecTy, fp16x2Vec1, v[3 ], b.i32_val (1 ));
127-
128- Value a0 = b.bitcast (fp16x2Vec0, i32_ty);
129- Value a1 = b.bitcast (fp16x2Vec1, i32_ty);
130116
131- a0 = b.and_ (i32_ty, a0, b.i32_val (0xfffefffe ));
132- a1 = b.and_ (i32_ty, a1, b.i32_val (0xfffefffe ));
133-
134- a0 = b.add (i32_ty, a0, b.i32_val (0x00800080 ));
135- a1 = b.add (i32_ty, a1, b.i32_val (0x00800080 ));
136-
137- auto fp8x4VecTy = vec_ty (i8_ty, 4 );
138- a0 = b.bitcast (a0, fp8x4VecTy);
139- a1 = b.bitcast (a1, fp8x4VecTy);
117+ SmallVector<Value> result (4 );
118+ for (size_t i = 0 ; i < 4 ; ++i) {
119+ Value fp16 = v[i];
120+ Value i16 = b.bitcast (fp16, i16_ty);
121+
122+ Value s = b.and_ (i16_ty, i16 , b.i16_val (0x8000 ));
123+ Value exp =
124+ b.and_ (i16_ty, b.lshr (i16_ty, i16 , b.i16_val (10 )), b.i16_val (0x1F ));
125+ Value man = b.and_ (i16_ty, i16 , b.i16_val (0x03FF ));
126+ Value sig = b.and_ (i16_ty, i16 , b.i16_val (0x7FFF ));
127+
128+ // Round 10-bit mantissa to 2-bit nearest, ties to even
129+ Value bias = b.add (
130+ i16_ty,
131+ b.lshr (i16_ty, b.and_ (i16_ty, sig, b.i16_val (0x0100 )), b.i16_val (8 )),
132+ b.i16_val (0x007F ));
133+ i16 = b.add (i16_ty, sig, bias);
134+
135+ // Handle overflow using saturation mode, by setting sig to be the max.
136+ // Any number equal or larger than 0x7B80 after rounding (including
137+ // infinite 0x7C00) will cause overflow
138+ i16 = b.select (b.icmp_uge (sig, b.i16_val (0x7B80 )), b.i16_val (0x7B00 ), i16 );
139+
140+ // Handle NaN value by keeping it Nan
141+ i16 = b.select (
142+ b.and_ (b.icmp_eq (exp, b.i16_val (0x1F )), b.icmp_ne (man, b.i16_val (0x0 ))),
143+ b.i16_val (0x7E00 ), i16 );
144+
145+ // Add sign bit
146+ i16 = b.or_ (i16_ty, s, i16 );
147+
148+ // Truncate to 8-bit
149+ result[i] = b.trunc (i8_ty, b.lshr (i16_ty, i16 , b.i16_val (8 )));
150+ }
140151
141- return {b.extract_element (i8_ty, a0, b.i32_val (1 )),
142- b.extract_element (i8_ty, a0, b.i32_val (3 )),
143- b.extract_element (i8_ty, a1, b.i32_val (1 )),
144- b.extract_element (i8_ty, a1, b.i32_val (3 ))};
152+ return result;
145153}
146154
147155static SmallVector<Value>
@@ -377,15 +385,98 @@ static SmallVector<Value> Fp32_to_Fp8E4M3FN(Location loc,
377385 v[0 ], v[1 ]);
378386}
379387
380- // Convert Fp32 to OCP Bf8 on CDNA4
388+ // Fp32 -> OCP Bf8 (RTNE)
389+
381390static SmallVector<Value>
382- Fp32_to_Fp8E5M2_RTNE (Location loc, ConversionPatternRewriter &rewriter,
383- const SmallVector<Value> &v) {
391+ Fp32_to_Fp8E5M2_RTNE_SW (Location loc, ConversionPatternRewriter &rewriter,
392+ const SmallVector<Value> &v) {
393+ assert (v.size () == 4 );
394+ auto b = TritonLLVMOpBuilder (loc, rewriter);
395+
396+ SmallVector<Value> result (4 );
397+ for (size_t i = 0 ; i < 4 ; ++i) {
398+ Value fp32 = v[i];
399+ Value i32 = b.bitcast (fp32, i32_ty);
400+
401+ Value s = b.and_ (i32_ty, i32 , b.i32_val (0x80000000 ));
402+ Value exp =
403+ b.and_ (i32_ty, b.lshr (i32_ty, i32 , b.i32_val (23 )), b.i32_val (0xFF ));
404+ Value man = b.and_ (i32_ty, i32 , b.i32_val (0x007FFFFF ));
405+
406+ // Convert 8-bit exponent to 5-bit
407+ Value exp5 = b.select (b.icmp_ult (exp, b.i32_val (0x71 )), b.i32_val (0 ),
408+ b.sub (i32_ty, exp, b.i32_val (0x70 )));
409+
410+ // Handle subnormal values (exp5 = 0)
411+ // - exp < 0x6e: mantissa = 0x00000000 (0)
412+ // - exp == 0x6e: mantissa = 0x00000000 (0),
413+ // 0x00200000 (1/4)
414+ // - exp == 0x6f: mantissa = 0x00200000 (1/4),
415+ // 0x00400000 (1/2)
416+ // - exp == 0x70: mantissa = 0x00400000 (1/2),
417+ // 0x00600000 (3/4),
418+ // 0x00800000 (1)
419+ man = b.select (b.icmp_ult (exp, b.i32_val (0x6e )), b.i32_val (0 ), man);
420+ man = b.select (b.icmp_eq (exp, b.i32_val (0x6e )),
421+ b.select (b.icmp_ne (man, b.i32_val (0 )), b.i32_val (0x00200000 ),
422+ b.i32_val (0 )),
423+ man);
424+ man = b.select (b.icmp_eq (exp, b.i32_val (0x6f )),
425+ b.select (b.icmp_uge (man, b.i32_val (0x00400000 )),
426+ b.i32_val (0x00400000 ), b.i32_val (0x00200000 )),
427+ man);
428+ man = b.select (
429+ b.icmp_eq (exp, b.i32_val (0x70 )),
430+ b.select (b.icmp_ugt (man, b.i32_val (0x00200000 )),
431+ b.select (b.icmp_uge (man, b.i32_val (0x00600000 )),
432+ b.i32_val (0x00800000 ), b.i32_val (0x00600000 )),
433+ b.i32_val (0x00400000 )),
434+ man);
435+
436+ // Round 23-bit mantissa to 2-bit nearest, ties to even
437+ Value sig = b.or_ (i32_ty, b.shl (i32_ty, exp5, b.i32_val (23 )), man);
438+ Value bias =
439+ b.add (i32_ty,
440+ b.lshr (i32_ty, b.and_ (i32_ty, sig, b.i32_val (0x00200000 )),
441+ b.i32_val (21 )),
442+ b.i32_val (0x000FFFFF ));
443+ i32 = b.add (i32_ty, sig, bias);
444+
445+ // Handle overflow using saturation mode, by setting sig to be the max.
446+ // Overflow will happe for the following cases:
447+ // - Any number equal or larger than 0x0F700000 after rounding
448+ // - Exponent larged than 0x8E (including infinite 0xFF)
449+ i32 = b.select (b.or_ (b.icmp_ugt (exp, b.i32_val (0x8E )),
450+ b.icmp_uge (sig, b.i32_val (0x0F700000 ))),
451+ b.i32_val (0x0F7FFFFF ), i32 );
452+
453+ // Handle NaN value by keeping it Nan
454+ i32 = b.select (
455+ b.and_ (b.icmp_eq (exp, b.i32_val (0xFF )), b.icmp_ne (man, b.i32_val (0x0 ))),
456+ b.i32_val (0x0FC00000 ), i32 );
457+
458+ // Add sign bit
459+ i32 = b.or_ (i32_ty, b.lshr (i32_ty, s, b.i32_val (3 )), i32 );
460+
461+ // Truncate to 8-bit
462+ result[i] = b.trunc (i8_ty, b.lshr (i32_ty, i32 , b.i32_val (21 )));
463+ }
464+ return result;
465+ }
466+
467+ static SmallVector<Value>
468+ Fp32_to_Fp8E5M2_RTNE_HW (Location loc, ConversionPatternRewriter &rewriter,
469+ const SmallVector<Value> &v) {
384470 assert (v.size () == 2 );
385471 return cvtScalePkDowncastToFp8<ROCDL::CvtScaleF32PkBf8F32Op>(loc, rewriter,
386472 v[0 ], v[1 ]);
387473}
388474
475+ ConverterT Fp32_to_Fp8E5M2_RTNE (AMD::ISAFamily isaFamily) {
476+ return isaFamily == AMD::ISAFamily::CDNA4 ? Fp32_to_Fp8E5M2_RTNE_HW
477+ : Fp32_to_Fp8E5M2_RTNE_SW;
478+ }
479+
389480// Fp32 -> Nanoo Bf8 on CDNA3
390481static SmallVector<Value>
391482Fp32_to_Fp8E5M2FNUZ (Location loc, ConversionPatternRewriter &rewriter,
@@ -853,86 +944,77 @@ ConverterT Fp8E5M2_to_Bf16(AMD::ISAFamily isaFamily) {
853944static SmallVector<Value>
854945Bf16_to_Fp8E5M2_SW (Location loc, ConversionPatternRewriter &rewriter,
855946 const SmallVector<Value> &v) {
947+ assert (v.size () == 4 );
856948 auto b = TritonLLVMOpBuilder (loc, rewriter);
857- auto bf16x2VecTy = vec_ty (bf16_ty, 2 );
858- Value bf16x2Vec0 = b.undef (bf16x2VecTy);
859- Value bf16x2Vec1 = b.undef (bf16x2VecTy);
860- bf16x2Vec0 = b.insert_element (bf16x2VecTy, bf16x2Vec0, v[0 ], b.i32_val (0 ));
861- bf16x2Vec0 = b.insert_element (bf16x2VecTy, bf16x2Vec0, v[1 ], b.i32_val (1 ));
862- bf16x2Vec1 = b.insert_element (bf16x2VecTy, bf16x2Vec1, v[2 ], b.i32_val (0 ));
863- bf16x2Vec1 = b.insert_element (bf16x2VecTy, bf16x2Vec1, v[3 ], b.i32_val (1 ));
864- bf16x2Vec0 = b.bitcast (bf16x2Vec0, i32_ty);
865- bf16x2Vec1 = b.bitcast (bf16x2Vec1, i32_ty);
866-
867- Value sign0 = b.and_ (i32_ty, bf16x2Vec0, b.i32_val (0x80008000 ));
868- Value sign1 = b.and_ (i32_ty, bf16x2Vec1, b.i32_val (0x80008000 ));
869- auto fp8x4VecTy = vec_ty (i8_ty, 4 );
870- Value sign = b.undef (fp8x4VecTy);
871- sign0 = b.bitcast (sign0, fp8x4VecTy);
872- sign1 = b.bitcast (sign1, fp8x4VecTy);
873- sign = b.insert_element (fp8x4VecTy, sign,
874- b.extract_element (i8_ty, sign0, b.i32_val (1 )),
875- b.i32_val (0 ));
876- sign = b.insert_element (fp8x4VecTy, sign,
877- b.extract_element (i8_ty, sign0, b.i32_val (3 )),
878- b.i32_val (1 ));
879- sign = b.insert_element (fp8x4VecTy, sign,
880- b.extract_element (i8_ty, sign1, b.i32_val (1 )),
881- b.i32_val (2 ));
882- sign = b.insert_element (fp8x4VecTy, sign,
883- b.extract_element (i8_ty, sign1, b.i32_val (3 )),
884- b.i32_val (3 ));
885- sign = b.bitcast (sign, i32_ty);
886-
887- Value nosign0 = b.and_ (i32_ty, bf16x2Vec0, b.i32_val (0x7fff7fff ));
888- Value nosign1 = b.and_ (i32_ty, bf16x2Vec1, b.i32_val (0x7fff7fff ));
889-
890- Value nosign_0_0 = b.and_ (i32_ty, nosign0, b.i32_val (0xffff0000 ));
891- nosign_0_0 = b.umax (i32_ty, nosign_0_0, b.i32_val (0x38000000 ));
892- nosign_0_0 = b.umin (i32_ty, nosign_0_0, b.i32_val (0x57e00000 ));
893- Value nosign_0_1 = b.and_ (i32_ty, nosign0, b.i32_val (0x0000ffff ));
894- nosign_0_1 = b.umax (i32_ty, nosign_0_1, b.i32_val (0x3800 ));
895- nosign_0_1 = b.umin (i32_ty, nosign_0_1, b.i32_val (0x57e0 ));
896- nosign0 = b.or_ (i32_ty, nosign_0_0, nosign_0_1);
897-
898- Value nosign_1_0 = b.and_ (i32_ty, nosign1, b.i32_val (0xffff0000 ));
899- nosign_1_0 = b.umax (i32_ty, nosign_1_0, b.i32_val (0x38000000 ));
900- nosign_1_0 = b.umin (i32_ty, nosign_1_0, b.i32_val (0x57e00000 ));
901- Value nosign_1_1 = b.and_ (i32_ty, nosign1, b.i32_val (0x0000ffff ));
902- nosign_1_1 = b.umax (i32_ty, nosign_1_1, b.i32_val (0x3800 ));
903- nosign_1_1 = b.umin (i32_ty, nosign_1_1, b.i32_val (0x57e0 ));
904- nosign1 = b.or_ (i32_ty, nosign_1_0, nosign_1_1);
905-
906- nosign0 = b.add (i32_ty, nosign0, b.i32_val (0x00100010 ));
907- nosign1 = b.add (i32_ty, nosign1, b.i32_val (0x00100010 ));
908- nosign0 = b.sub (i32_ty, nosign0, b.i32_val (0x38003800 ));
909- nosign1 = b.sub (i32_ty, nosign1, b.i32_val (0x38003800 ));
910- nosign0 = b.shl (i32_ty, nosign0, b.i32_val (3 ));
911- nosign1 = b.shl (i32_ty, nosign1, b.i32_val (3 ));
912-
913- nosign0 = b.bitcast (nosign0, fp8x4VecTy);
914- nosign1 = b.bitcast (nosign1, fp8x4VecTy);
915- Value nosign = b.undef (fp8x4VecTy);
916- nosign = b.insert_element (fp8x4VecTy, nosign,
917- b.extract_element (i8_ty, nosign0, b.i32_val (1 )),
918- b.i32_val (0 ));
919- nosign = b.insert_element (fp8x4VecTy, nosign,
920- b.extract_element (i8_ty, nosign0, b.i32_val (3 )),
921- b.i32_val (1 ));
922- nosign = b.insert_element (fp8x4VecTy, nosign,
923- b.extract_element (i8_ty, nosign1, b.i32_val (1 )),
924- b.i32_val (2 ));
925- nosign = b.insert_element (fp8x4VecTy, nosign,
926- b.extract_element (i8_ty, nosign1, b.i32_val (3 )),
927- b.i32_val (3 ));
928- nosign = b.bitcast (nosign, i32_ty);
929-
930- Value fp8x4Vec = b.or_ (i32_ty, nosign, sign);
931- fp8x4Vec = b.bitcast (fp8x4Vec, fp8x4VecTy);
932- return {b.extract_element (i8_ty, fp8x4Vec, b.i32_val (0 )),
933- b.extract_element (i8_ty, fp8x4Vec, b.i32_val (1 )),
934- b.extract_element (i8_ty, fp8x4Vec, b.i32_val (2 )),
935- b.extract_element (i8_ty, fp8x4Vec, b.i32_val (3 ))};
949+
950+ SmallVector<Value> result (4 );
951+ for (size_t i = 0 ; i < 4 ; ++i) {
952+ Value fp16 = v[i];
953+ Value i16 = b.bitcast (fp16, i16_ty);
954+
955+ Value s = b.and_ (i16_ty, i16 , b.i16_val (0x8000 ));
956+ Value exp =
957+ b.and_ (i16_ty, b.lshr (i16_ty, i16 , b.i16_val (7 )), b.i16_val (0xFF ));
958+ Value man = b.and_ (i16_ty, i16 , b.i16_val (0x7F ));
959+
960+ // Convert 8-bit exponent to 5-bit exponent
961+ Value exp5 = b.select (b.icmp_ult (exp, b.i16_val (0x71 )), b.i16_val (0 ),
962+ b.sub (i16_ty, exp, b.i16_val (0x70 )));
963+
964+ // Handle subnormal values (exp5 = 0)
965+ // - exp < 0x6e: mantissa = 0x0000 (0)
966+ // - exp == 0x6e: mantissa = 0x0000 (0),
967+ // 0x0020 (1/4)
968+ // - exp == 0x6f: mantissa = 0x0020 (1/4),
969+ // 0x0040 (1/2)
970+ // - exp == 0x70: mantissa = 0x0040 (1/2),
971+ // 0x0060 (3/4),
972+ // 0x0080 (1)
973+ man = b.select (b.icmp_ult (exp, b.i16_val (0x6e )), b.i16_val (0 ), man);
974+ man = b.select (
975+ b.icmp_eq (exp, b.i16_val (0x6e )),
976+ b.select (b.icmp_ne (man, b.i16_val (0 )), b.i16_val (0x0020 ), b.i16_val (0 )),
977+ man);
978+ man = b.select (b.icmp_eq (exp, b.i16_val (0x6f )),
979+ b.select (b.icmp_uge (man, b.i16_val (0x0040 )),
980+ b.i16_val (0x0040 ), b.i16_val (0x0020 )),
981+ man);
982+ man = b.select (b.icmp_eq (exp, b.i16_val (0x70 )),
983+ b.select (b.icmp_ugt (man, b.i16_val (0x0020 )),
984+ b.select (b.icmp_uge (man, b.i16_val (0x0060 )),
985+ b.i16_val (0x0080 ), b.i16_val (0x0060 )),
986+ b.i16_val (0x0040 )),
987+ man);
988+
989+ // Round 7-bit mantissa to 2-bit
990+ Value sig = b.or_ (i16_ty, b.shl (i16_ty, exp5, b.i16_val (7 )), man);
991+ Value bias = b.add (
992+ i16_ty,
993+ b.lshr (i16_ty, b.and_ (i16_ty, sig, b.i16_val (0x0020 )), b.i16_val (5 )),
994+ b.i16_val (0x000F ));
995+ i16 = b.add (i16_ty, sig, bias);
996+
997+ // Handle overflow using saturation mode, by setting sig to be the max.
998+ // Overflow will happe for the following cases:
999+ // - Any number equal or larger than 0x0F70 after rounding
1000+ // - Exponent larged than 0x8E (including infinite 0xFF)
1001+ i16 = b.select (b.or_ (b.icmp_ugt (exp, b.i16_val (0x8E )),
1002+ b.icmp_uge (sig, b.i16_val (0x0F70 ))),
1003+ b.i16_val (0x0F7F ), i16 );
1004+
1005+ // Handle NaN value by keeping it Nan
1006+ i16 = b.select (
1007+ b.and_ (b.icmp_eq (exp, b.i16_val (0xFF )), b.icmp_ne (man, b.i16_val (0x0 ))),
1008+ b.i16_val (0x0FC0 ), i16 );
1009+
1010+ // Add sign bit
1011+ i16 = b.or_ (i16_ty, b.lshr (i16_ty, s, b.i16_val (3 )), i16 );
1012+
1013+ // Truncate to 8-bit
1014+ result[i] = b.trunc (i8_ty, b.lshr (i16_ty, i16 , b.i16_val (5 )));
1015+ }
1016+
1017+ return result;
9361018}
9371019
9381020static SmallVector<Value>
@@ -1262,7 +1344,8 @@ struct FpToFpOpConversion
12621344 {{F32TyID, F8E5M2FNUZTyID, RoundingMode::RTNE},
12631345 Fp32_to_Fp8E5M2FNUZ},
12641346 {{F32TyID, F8E4M3FNTyID, RoundingMode::RTNE}, Fp32_to_Fp8E4M3FN},
1265- {{F32TyID, F8E5M2TyID, RoundingMode::RTNE}, Fp32_to_Fp8E5M2_RTNE},
1347+ {{F32TyID, F8E5M2TyID, RoundingMode::RTNE},
1348+ Fp32_to_Fp8E5M2_RTNE (isaFamily)},
12661349 {{F32TyID, F8E5M2TyID, RoundingMode::RTZ}, Fp32_to_Fp8E5M2_RTZ},
12671350 {{F8E4M3FNUZTyID, F32TyID, undefRounding}, Fp8E4M3FNUZ_to_Fp32},
12681351 {{F8E5M2FNUZTyID, F32TyID, undefRounding}, Fp8E5M2FNUZ_to_Fp32},
@@ -1318,16 +1401,23 @@ struct FpToFpOpConversion
13181401 numElements = 4 ;
13191402 }
13201403
1321- // f32->fp8/bf8 with rtne, if neither nanoo fp8/bf8 on CDNA3 nor ocp fp8/bf8
1322- // on CDNA4, is done in two steps: f32->fp16 with rtne and fp16->fp8/bf8
1323- // with rtne
1404+ // fp32 -> fp8 with rtne can be done in two steps:
1405+ // - fp32 -> fp16 with rtne and
1406+ // - fp16 -> fp8 with rtne
1407+ // with the following exceptions:
1408+ // 1. fp32 -> ocp fp8/bf8 on CDNA4: has hardware support
1409+ // 2. fp32 -> nanoo fp8/bf8 on non-CDNA4: has hardware support
1410+ // 3. fp32 -> ocp bf8 on non-CDNA4: has software support
13241411 bool useFP16IntermediateSrc =
13251412 srcElementType.isF32 () && !dstElementType.isF16 () &&
13261413 roundingMode == RoundingMode::RTNE &&
13271414 !(isaFamily == AMD::ISAFamily::CDNA4 &&
13281415 (llvm::isa<Float8E4M3FNType, Float8E5M2Type>(dstElementType))) &&
13291416 !(isaFamily == AMD::ISAFamily::CDNA3 &&
1330- (llvm::isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(dstElementType)));
1417+ (llvm::isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(
1418+ dstElementType))) &&
1419+ !(isaFamily != AMD::ISAFamily::CDNA4 &&
1420+ (llvm::isa<Float8E5M2Type>(dstElementType)));
13311421
13321422 // fp8/bf8->f32, if neither nanoo fp8/bf8 on CDNA3 nor ocp fp8/bf8 on CDNA4,
13331423 // is done in two steps: fp8/bf8->fp16 and fp16->fp32
0 commit comments