@@ -85,6 +85,18 @@ Fp16_to_Fp8E5M2_RTZ(Location loc, ConversionPatternRewriter &rewriter,
8585 b.extract_element (i8_ty, a1, b.i32_val (3 ))};
8686}
8787
88+ static Value checkIsNan (TritonLLVMOpBuilder &builder, Value v) {
89+ StringRef intrinsic = " llvm.is.fpclass" ;
90+ // bits 0 and 1 indicate signaling Nan and quiet Nan, respectively
91+ Location loc = builder.loc ;
92+ OpBuilder &rewriter = *builder.builder ;
93+ Value nanBits = builder.i32_val (3 );
94+
95+ return LLVM::createLLVMIntrinsicCallOp (rewriter, loc, intrinsic, i1_ty,
96+ ValueRange{v, nanBits})
97+ ->getResult (0 );
98+ }
99+
88100// Cast FP16 to FP8E4M3FN in saturation and round-to-nearest-even mode.
89101// According to
90102// https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1,
@@ -94,11 +106,7 @@ static Value
94106Fp16_to_Fp8E4M3FN_RTNE_oneValue (Location loc,
95107 ConversionPatternRewriter &rewriter, Value v) {
96108 auto b = TritonLLVMOpBuilder (loc, rewriter);
97- StringRef funcName = " llvm.is.fpclass" ;
98- Value isNaN = LLVM::createLLVMIntrinsicCallOp (rewriter, loc, funcName, i1_ty,
99- {v, b.i32_val (0x3 )})
100- ->getResult (0 );
101-
109+ Value isNaN = checkIsNan (b, v);
102110 // Get sign and absolute value
103111 Value vi16 = b.bitcast (v, i16_ty);
104112 Value sign =
@@ -441,27 +449,25 @@ static Value convertFp32ToBf16(Location loc,
441449 auto truncated = b.trunc (i16_ty, shifted);
442450 return b.bitcast (truncated, bf16_ty);
443451 }
444- // Otherwise it is (rounding == RoundingMode::RTNE)
445- auto as_uint32 = b.bitcast (v, i32_ty);
446- auto check_exponent =
447- b.and_ (i32_ty, b.xor_ (i32_ty, as_uint32, b.i32_val (0xffffffff )),
448- b.i32_val (0x7f800000 ));
449- auto exponent_not_all1s = b.icmp_ne (check_exponent, b.i32_val (0 ));
450- auto exponent_all1s = b.icmp_eq (check_exponent, b.i32_val (0 ));
451- auto rounded = b.add (
452- i32_ty, b.i32_val (0x7fff ),
453- b.and_ (i32_ty, b.lshr (i32_ty, as_uint32, b.i32_val (16 )), b.i32_val (1 )));
454- rounded = b.add (i32_ty, rounded, as_uint32);
455- auto res = b.select (exponent_not_all1s, rounded, as_uint32);
456-
457- auto preserve_nan = b.and_ (
458- i1_ty, exponent_all1s,
459- b.icmp_ne (b.and_ (i32_ty, as_uint32, b.i32_val (0xffff )), b.i32_val (0 )));
460- auto nan = b.or_ (i32_ty, as_uint32, b.i32_val (0x10000 ));
461- res = b.select (preserve_nan, nan, res);
462-
463- auto shifted = b.lshr (i32_ty, res, b.i32_val (16 ));
464- auto truncated = b.trunc (i16_ty, shifted);
452+
453+ // This implementation is a faster version for fp32 to bf16 type conversion
454+ // It is from CK:
455+ // https://github.com/cgmillette/composable_kernel/commit/24e75bef6aa5
456+ // It uses less VGPR and less number of instructions compared to the
457+ // previous implementation
458+ Value isNan = checkIsNan (b, v);
459+ Value v16 = b.i32_val (16 );
460+ Value tmp = b.and_ (i32_ty, b.lshr (i32_ty, as_int32, v16), b.i32_val (1 ));
461+
462+ Value v7FFF = b.i32_val (0x7FFF );
463+ Value s1 = b.add (as_int32, tmp);
464+ Value s2 = b.add (s1, v7FFF);
465+
466+ Value vNan = b.i32_val (0x7FFF0000 );
467+ Value res = b.select (isNan, vNan, s2);
468+
469+ Value shifted = b.lshr (i32_ty, res, v16);
470+ Value truncated = b.trunc (i16_ty, shifted);
465471 return b.bitcast (truncated, bf16_ty);
466472}
467473
0 commit comments