
Commit 1c28e08

[AMD] Use more efficient fp32 to bf16 type conversion (triton-lang#5633)
This PR uses a more efficient approach for the fp32 to bf16 type conversion in the HIP backend. According to a simple unit test, the number of VGPRs used decreases from 18 to 10.
1 parent 0ecb172 commit 1c28e08
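
For context, the new RTNE lowering works on the raw fp32 bits: it adds 0x7FFF plus the parity bit of what will become the bf16 LSB (round to nearest, ties to even), and special-cases NaN with an unordered float compare. Below is a minimal scalar sketch of that bit manipulation in C++, written for this edit rather than taken from the commit; the helper name fp32_to_bf16_rtne is made up, and each step is annotated with the GCN instruction it models.

#include <cstdint>
#include <cstring>

uint16_t fp32_to_bf16_rtne(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));                // bitcast(v, i32_ty)
  bool is_nan = f != f;                          // v_cmp_u_f32 f, f (unordered)
  uint32_t lsb = (u >> 16) & 1u;                 // v_bfe_u32 u, 16, 1
  uint32_t rounded = u + lsb + 0x7FFFu;          // v_add3_u32 u, lsb, 0x7FFF
  uint32_t res = is_nan ? 0x7FFF0000u : rounded; // v_cndmask_b32
  return static_cast<uint16_t>(res >> 16);       // lshr + trunc
}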

File tree

2 files changed: +72 -22 lines changed
Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+// RUN: triton-opt %s --convert-triton-amdgpu-to-llvm='arch=gfx942' | FileCheck %s
+
+// CHECK-LABEL: llvm.func @fp32_to_bf16
+// CHECK: llvm.inline_asm {{.*}} "v_cmp_u_f32 $0, $1, $2", "=s,v,v"
+// CHECK: llvm.inline_asm {{.*}} "v_bfe_u32 $0, $1, $2, $3", "=v,v,v,v"
+// CHECK: llvm.inline_asm {{.*}} "v_add3_u32 $0, $1, $2, $3", "=v,v,v,v"
+// CHECK: llvm.inline_asm {{.*}} "v_cndmask_b32 $0, $1, $2, $3", "=v,v,v,s"
+
+#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @fp32_to_bf16(
+      %arg: tensor<256xf32, #blocked>) {
+    %8 = arith.truncf %arg : tensor<256xf32, #blocked> to tensor<256xbf16, #blocked>
+    tt.return
+  }
+}

third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 56 additions & 22 deletions
@@ -433,36 +433,70 @@ static Value convertBf16ToFp32(Location loc,
   return bitcast(shifted, f32_ty);
 }
 
+static Value buildGCNInstruction(Location loc, RewriterBase &rewriter,
+                                 StringRef instrName,
+                                 ArrayRef<StringRef> constraints,
+                                 ArrayRef<Value> vals, Type retType) {
+  assert(constraints.size() == vals.size() + 1);
+  assert(vals.size() == 2 || vals.size() == 3);
+  GCNBuilder builder;
+  GCNInstr &instr = *builder.create(instrName.str());
+  GCNBuilder::Operand *out = builder.newOperand(constraints[0]);
+  SmallVector<GCNBuilder::Operand *> operands;
+  for (int i = 0; i < vals.size(); ++i) {
+    operands.push_back(builder.newOperand(vals[i], constraints[i + 1]));
+  }
+
+  if (vals.size() == 2) {
+    instr(out, operands[0], operands[1]);
+  } else {
+    instr(out, operands[0], operands[1], operands[2]);
+  }
+
+  return builder.launch(rewriter, loc, retType, false);
+}
+
 static Value convertFp32ToBf16(Location loc,
                                ConversionPatternRewriter &rewriter,
                                const Value &v, const RoundingMode rounding) {
+  auto as_int32 = bitcast(v, i32_ty);
   if (rounding == RoundingMode::RTZ) {
-    auto as_int32 = bitcast(v, i32_ty);
     auto shifted = lshr(i32_ty, as_int32, i32_val(16));
     auto truncated = trunc(i16_ty, shifted);
     return bitcast(truncated, bf16_ty);
   }
-  // Otherwise it is (rounding == RoundingMode::RTNE)
-  auto as_uint32 = bitcast(v, i32_ty);
-  auto check_exponent =
-      and_(i32_ty, xor_(i32_ty, as_uint32, i32_val(0xffffffff)),
-           i32_val(0x7f800000));
-  auto exponent_not_all1s = icmp_ne(check_exponent, i32_val(0));
-  auto exponent_all1s = icmp_eq(check_exponent, i32_val(0));
-  auto rounded =
-      add(i32_ty, i32_val(0x7fff),
-          and_(i32_ty, lshr(i32_ty, as_uint32, i32_val(16)), i32_val(1)));
-  rounded = add(i32_ty, rounded, as_uint32);
-  auto res = select(exponent_not_all1s, rounded, as_uint32);
-
-  auto preserve_nan =
-      and_(i1_ty, exponent_all1s,
-           icmp_ne(and_(i32_ty, as_uint32, i32_val(0xffff)), i32_val(0)));
-  auto nan = or_(i32_ty, as_uint32, i32_val(0x10000));
-  res = select(preserve_nan, nan, res);
-
-  auto shifted = lshr(i32_ty, res, i32_val(16));
-  auto truncated = trunc(i16_ty, shifted);
+
+  // This implementation is a faster version of the fp32 to bf16 type
+  // conversion. It is from CK:
+  // https://github.com/cgmillette/composable_kernel/commit/24e75bef6aa5
+  // It uses fewer VGPRs and fewer instructions than the previous
+  // implementation.
+  SmallVector<StringRef> constraints0 = {"=s", "v", "v"};
+  SmallVector<Value> vals0 = {v, v};
+  Value isNan = buildGCNInstruction(loc, rewriter, "v_cmp_u_f32", constraints0,
+                                    vals0, i64_ty);
+
+  Value v16 = i32_val(16);
+  Value v1 = i32_val(1);
+  SmallVector<StringRef> constraints1 = {"=v", "v", "v", "v"};
+  SmallVector<Value> vals1 = {v, v16, v1};
+  Value tmp = buildGCNInstruction(loc, rewriter, "v_bfe_u32", constraints1,
+                                  vals1, i32_ty);
+
+  SmallVector<StringRef> constraints2 = {"=v", "v", "v", "v"};
+  Value v7FFF = i32_val(0x7FFF);
+  SmallVector<Value> vals2 = {v, tmp, v7FFF};
+  Value tmp1 = buildGCNInstruction(loc, rewriter, "v_add3_u32", constraints2,
+                                   vals2, i32_ty);
+
+  SmallVector<StringRef> constraints3 = {"=v", "v", "v", "s"};
+  Value vNan = i32_val(0x7FFF0000);
+  SmallVector<Value> vals3 = {tmp1, vNan, isNan};
+  Value cndMask = buildGCNInstruction(loc, rewriter, "v_cndmask_b32",
+                                      constraints3, vals3, i32_ty);
+
+  Value shifted = lshr(i32_ty, cndMask, v16);
+  Value truncated = trunc(i16_ty, shifted);
   return bitcast(truncated, bf16_ty);
 }
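
As a sanity check on the rewrite, the removed LLVM-IR sequence and the new inline-asm sequence can be compared with scalar C++ models. The sketch below is written for this edit and ships nowhere in the commit; names like old_model and new_model are invented. The two paths agree on every ordered input; on NaN they differ only in payload, since the old path preserves payload bits while the new one canonicalizes to 0x7FFF, so the check merely requires that both still produce a bf16 NaN.

#include <cassert>
#include <cstdint>
#include <cstring>

// Model of the removed LLVM-IR sequence (exponent check + selects).
static uint16_t old_model(uint32_t u) {
  uint32_t check_exponent = (u ^ 0xFFFFFFFFu) & 0x7F800000u; // 0 iff exponent all 1s
  uint32_t rounded = u + 0x7FFFu + ((u >> 16) & 1u);
  uint32_t res = (check_exponent != 0) ? rounded : u;
  if (check_exponent == 0 && (u & 0xFFFFu) != 0)
    res = u | 0x10000u; // keep NaN-ness that lives only in the low 16 bits
  return static_cast<uint16_t>(res >> 16);
}

// Model of the new v_cmp_u_f32 / v_bfe_u32 / v_add3_u32 / v_cndmask_b32 chain.
static uint16_t new_model(uint32_t u) {
  float f;
  std::memcpy(&f, &u, sizeof(f));
  uint32_t rounded = u + ((u >> 16) & 1u) + 0x7FFFu;
  return static_cast<uint16_t>(((f != f) ? 0x7FFF0000u : rounded) >> 16);
}

static bool is_bf16_nan(uint16_t r) {
  return (r & 0x7F80u) == 0x7F80u && (r & 0x7Fu) != 0;
}

int main() {
  // Stride through the fp32 bit space; a full 2^32 sweep also passes but
  // takes longer.
  for (uint64_t bits = 0; bits <= 0xFFFFFFFFull; bits += 9973) {
    uint32_t u = static_cast<uint32_t>(bits);
    float f;
    std::memcpy(&f, &u, sizeof(f));
    if (f != f)
      assert(is_bf16_nan(old_model(u)) && is_bf16_nan(new_model(u)));
    else
      assert(old_model(u) == new_model(u));
  }
  return 0;
}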

0 commit comments