Skip to content

Commit bde92ef

Browse files
authored
[AMD] Improve math.fdiv FTZ lowering for f32 inputs (triton-lang#5474)
This commit lowers math.fdiv to a truly approximated div operation, which helps to save register usage and improve performance.
1 parent 3b4d632 commit bde92ef

File tree

2 files changed

+96
-2
lines changed

2 files changed

+96
-2
lines changed

test/Conversion/amd/fdivide.mlir

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm="arch=gfx942" | FileCheck %s
2+
3+
// Verify that an f32 divide is lowered to the fast AMDGPU div sequence
// (div.scale -> rcp -> fmul -> div.fmas -> div.fixup) with no plain llvm.fdiv.
#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func public @test_fdiv_f32(%arg0: tensor<64xf32, #blocked>, %arg1: tensor<64xf32, #blocked>) attributes {noinline = false} {
    // CHECK-LABEL: test_fdiv_f32
    // CHECK: llvm.amdgcn.div.scale.f32
    // CHECK: llvm.amdgcn.div.scale.f32
    // CHECK: llvm.amdgcn.rcp.f32
    // CHECK: llvm.fmul
    // CHECK: llvm.amdgcn.div.fmas.f32
    // CHECK: llvm.amdgcn.div.fixup.f32
    // CHECK-NOT: llvm.fdiv
    %0 = arith.divf %arg0, %arg1 : tensor<64xf32, #blocked>
    tt.return
  }
}
// -----
// Verify that a non-f32 (f64) divide keeps the IEEE-compliant llvm.fdiv lowering.
#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func public @test_fdiv_f64(%arg0: tensor<64xf64, #blocked>, %arg1: tensor<64xf64, #blocked>) attributes {noinline = false} {
    // CHECK-LABEL: test_fdiv_f64
    // CHECK: llvm.fdiv
    %0 = arith.divf %arg0, %arg1 : tensor<64xf64, #blocked>
    tt.return
  }
}
// -----
// Verify that the precise (round-to-nearest) divide is NOT replaced by the
// approximate sequence and still lowers to llvm.fdiv.
#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func public @test_div_rn(%arg0: tensor<64xf32, #blocked>, %arg1: tensor<64xf32, #blocked>) attributes {noinline = false} {
    // CHECK-LABEL: test_div_rn
    // CHECK: llvm.fdiv
    %0 = tt.precise_divf %arg0, %arg1 : tensor<64xf32, #blocked>
    tt.return
  }
}

third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,9 +1081,62 @@ struct FDivOpConversion
10811081
ConversionPatternRewriter &rewriter,
10821082
Type elemTy, MultipleOperandsRange operands,
10831083
Location loc) const {
1084+
// For non-F32 inputs, the op is lowered to LLVM::FDivOp, which is an
// IEEE-compliant DIV operation.
1086+
if (elemTy.getIntOrFloatBitWidth() != 32)
1087+
return {rewriter.create<LLVM::FDivOp>(loc, elemTy, operands[0][0],
1088+
operands[0][1])};
1089+
1090+
auto b = TritonLLVMOpBuilder(loc, rewriter);
10841091

1085-
return {rewriter.create<LLVM::FDivOp>(loc, elemTy, operands[0][0],
1086-
operands[0][1])};
1092+
// The algorithm comes from
1093+
// https://github.com/llvm/llvm-project/blob/bda7aadf/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp#L4980-L5065
1094+
// with the Newton-Raphson refinement removed, to perform a faster,
1095+
// approximated DIV operation, aligning with the `div.full.f32` instruction
1096+
// on the NV backend.
1097+
Value &lhs = operands[0][0];
1098+
Value &rhs = operands[0][1];
1099+
MLIRContext *ctx = rewriter.getContext();
1100+
Type divScaleResType = struct_ty({elemTy, i1_ty});
1101+
1102+
// The `llvm.amdgcn.div.scale.f32` instruction's signature is
1103+
// (src0, src1, src2) -> (ret0, ret1), where
1104+
//
1105+
// src0: The numerator or lhs of FDivOp.
1106+
// src1: The denominator or rhs of FDivOp.
1107+
// src2: A boolean indicating which operand to scale. If true, lhs is
1108+
// scaled; Otherwise, rhs is scaled.
1109+
//
1110+
// ret0: The scaled operand.
1111+
// ret1: The VCC register indicating whether post-scaling is required.
1112+
auto denominatorScaleOp = LLVM::createLLVMIntrinsicCallOp(
1113+
rewriter, loc, "llvm.amdgcn.div.scale.f32", divScaleResType,
1114+
{lhs, rhs, b.false_val()});
1115+
Value denominatorScaled = b.extract_val(denominatorScaleOp.getResult(0), 0);
1116+
auto numeratorScaleOp = LLVM::createLLVMIntrinsicCallOp(
1117+
rewriter, loc, "llvm.amdgcn.div.scale.f32", divScaleResType,
1118+
{lhs, rhs, b.true_val()});
1119+
Value numeratorScaled = b.extract_val(numeratorScaleOp.getResult(0), 0);
1120+
Value vcc = b.extract_val(numeratorScaleOp.getResult(0), 1);
1121+
1122+
Value rcp =
1123+
LLVM::createLLVMIntrinsicCallOp(rewriter, loc, "llvm.amdgcn.rcp.f32",
1124+
elemTy, {denominatorScaled})
1125+
.getResult(0);
1126+
1127+
Value approxDiv = b.fmul(numeratorScaled, rcp);
1128+
1129+
// Since the Newton-Raphson refinement is skipped, we pass 0 instead of
// the refined approximations as the inputs.
1131+
auto fmas = LLVM::createLLVMIntrinsicCallOp(
1132+
rewriter, loc, "llvm.amdgcn.div.fmas.f32", elemTy,
1133+
{b.f32_val(0), b.f32_val(0), approxDiv, vcc})
1134+
.getResult(0);
1135+
1136+
return {LLVM::createLLVMIntrinsicCallOp(rewriter, loc,
1137+
"llvm.amdgcn.div.fixup.f32", elemTy,
1138+
{fmas, rhs, lhs})
1139+
.getResult(0)};
10871140
}
10881141
};
10891142

0 commit comments

Comments
 (0)