Skip to content

Commit 0524662

Browse files
authored
[AMD] Optimize fp32 to bf16 rtne conversion (#5869)
This PR is a re-implementation of the PR #5633, which is for more efficient approach for the type conversion from fp32 to bf16 in the hip backend. It avoid using inline asm, so the problem (`tl.store` is lowered to `global_store_ushort` instead of `global_store_dword`) related to #5633 is gone. It uses the same number of VGPRs and SGPRs as the main branch, but uses less number of instructions.
1 parent 08d7f64 commit 0524662

File tree

1 file changed

+32
-26
lines changed

1 file changed

+32
-26
lines changed

third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 32 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,18 @@ Fp16_to_Fp8E5M2_RTZ(Location loc, ConversionPatternRewriter &rewriter,
8585
b.extract_element(i8_ty, a1, b.i32_val(3))};
8686
}
8787

88+
static Value checkIsNan(TritonLLVMOpBuilder &builder, Value v) {
89+
StringRef intrinsic = "llvm.is.fpclass";
90+
// bits 0 and 1 indicate signaling Nan and quiet Nan, respectively
91+
Location loc = builder.loc;
92+
OpBuilder &rewriter = *builder.builder;
93+
Value nanBits = builder.i32_val(3);
94+
95+
return LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic, i1_ty,
96+
ValueRange{v, nanBits})
97+
->getResult(0);
98+
}
99+
88100
// Cast FP16 to FP8E4M3FN in saturation and round-to-nearest-even mode.
89101
// According to
90102
// https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1,
@@ -94,11 +106,7 @@ static Value
94106
Fp16_to_Fp8E4M3FN_RTNE_oneValue(Location loc,
95107
ConversionPatternRewriter &rewriter, Value v) {
96108
auto b = TritonLLVMOpBuilder(loc, rewriter);
97-
StringRef funcName = "llvm.is.fpclass";
98-
Value isNaN = LLVM::createLLVMIntrinsicCallOp(rewriter, loc, funcName, i1_ty,
99-
{v, b.i32_val(0x3)})
100-
->getResult(0);
101-
109+
Value isNaN = checkIsNan(b, v);
102110
// Get sign and absolute value
103111
Value vi16 = b.bitcast(v, i16_ty);
104112
Value sign =
@@ -441,27 +449,25 @@ static Value convertFp32ToBf16(Location loc,
441449
auto truncated = b.trunc(i16_ty, shifted);
442450
return b.bitcast(truncated, bf16_ty);
443451
}
444-
// Otherwise it is (rounding == RoundingMode::RTNE)
445-
auto as_uint32 = b.bitcast(v, i32_ty);
446-
auto check_exponent =
447-
b.and_(i32_ty, b.xor_(i32_ty, as_uint32, b.i32_val(0xffffffff)),
448-
b.i32_val(0x7f800000));
449-
auto exponent_not_all1s = b.icmp_ne(check_exponent, b.i32_val(0));
450-
auto exponent_all1s = b.icmp_eq(check_exponent, b.i32_val(0));
451-
auto rounded = b.add(
452-
i32_ty, b.i32_val(0x7fff),
453-
b.and_(i32_ty, b.lshr(i32_ty, as_uint32, b.i32_val(16)), b.i32_val(1)));
454-
rounded = b.add(i32_ty, rounded, as_uint32);
455-
auto res = b.select(exponent_not_all1s, rounded, as_uint32);
456-
457-
auto preserve_nan = b.and_(
458-
i1_ty, exponent_all1s,
459-
b.icmp_ne(b.and_(i32_ty, as_uint32, b.i32_val(0xffff)), b.i32_val(0)));
460-
auto nan = b.or_(i32_ty, as_uint32, b.i32_val(0x10000));
461-
res = b.select(preserve_nan, nan, res);
462-
463-
auto shifted = b.lshr(i32_ty, res, b.i32_val(16));
464-
auto truncated = b.trunc(i16_ty, shifted);
452+
453+
// This implementation is a faster version for fp32 to bf16 type conversion
454+
// It is from CK:
455+
// https://github.com/cgmillette/composable_kernel/commit/24e75bef6aa5
456+
// It uses less VGPR and less number of instructions compared to the
457+
// previous implementation
458+
Value isNan = checkIsNan(b, v);
459+
Value v16 = b.i32_val(16);
460+
Value tmp = b.and_(i32_ty, b.lshr(i32_ty, as_int32, v16), b.i32_val(1));
461+
462+
Value v7FFF = b.i32_val(0x7FFF);
463+
Value s1 = b.add(as_int32, tmp);
464+
Value s2 = b.add(s1, v7FFF);
465+
466+
Value vNan = b.i32_val(0x7FFF0000);
467+
Value res = b.select(isNan, vNan, s2);
468+
469+
Value shifted = b.lshr(i32_ty, res, v16);
470+
Value truncated = b.trunc(i16_ty, shifted);
465471
return b.bitcast(truncated, bf16_ty);
466472
}
467473

0 commit comments

Comments
 (0)