@@ -1081,9 +1081,62 @@ struct FDivOpConversion
10811081 ConversionPatternRewriter &rewriter,
10821082 Type elemTy, MultipleOperandsRange operands,
10831083 Location loc) const {
1084+ // For non-F32 input, it's lowered to LLVM::FDivOp, which is a
1085+ // IEEE-compliant DIV operation.
1086+ if (elemTy.getIntOrFloatBitWidth () != 32 )
1087+ return {rewriter.create <LLVM::FDivOp>(loc, elemTy, operands[0 ][0 ],
1088+ operands[0 ][1 ])};
1089+
1090+ auto b = TritonLLVMOpBuilder (loc, rewriter);
10841091
1085- return {rewriter.create <LLVM::FDivOp>(loc, elemTy, operands[0 ][0 ],
1086- operands[0 ][1 ])};
1092+ // The algorithm comes from
1093+ // https://github.com/llvm/llvm-project/blob/bda7aadf/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp#L4980-L5065
1094+ // with the Newton-Raphson refinement removed, to perform a faster,
1095+ // approximated DIV operation, aligning with the `div.full.f32` instruction
1096+ // on the NV backend.
1097+ Value &lhs = operands[0 ][0 ];
1098+ Value &rhs = operands[0 ][1 ];
1099+ MLIRContext *ctx = rewriter.getContext ();
1100+ Type divScaleResType = struct_ty ({elemTy, i1_ty});
1101+
1102+ // The `llvm.amdgcn.div.scale.f32` instruction's signature is
1103+ // (src0, src1, src2) -> (ret0, ret1), where
1104+ //
1105+ // src0: The numerator or lhs of FDivOp.
1106+ // src1: The denominator or rhs of FDivOp.
1107+ // src2: A boolean indicating which operand to scale. If true, lhs is
1108+ // scaled; Otherwise, rhs is scaled.
1109+ //
1110+ // ret0: The scaled operand.
1111+ // ret1: The VCC register indicating whether post-scaling is required.
1112+ auto denominatorScaleOp = LLVM::createLLVMIntrinsicCallOp (
1113+ rewriter, loc, " llvm.amdgcn.div.scale.f32" , divScaleResType,
1114+ {lhs, rhs, b.false_val ()});
1115+ Value denominatorScaled = b.extract_val (denominatorScaleOp.getResult (0 ), 0 );
1116+ auto numeratorScaleOp = LLVM::createLLVMIntrinsicCallOp (
1117+ rewriter, loc, " llvm.amdgcn.div.scale.f32" , divScaleResType,
1118+ {lhs, rhs, b.true_val ()});
1119+ Value numeratorScaled = b.extract_val (numeratorScaleOp.getResult (0 ), 0 );
1120+ Value vcc = b.extract_val (numeratorScaleOp.getResult (0 ), 1 );
1121+
1122+ Value rcp =
1123+ LLVM::createLLVMIntrinsicCallOp (rewriter, loc, " llvm.amdgcn.rcp.f32" ,
1124+ elemTy, {denominatorScaled})
1125+ .getResult (0 );
1126+
1127+ Value approxDiv = b.fmul (numeratorScaled, rcp);
1128+
1129+ // Since the Newton-Raphson is skipped, we use 0 instead of approximations
1130+ // as the inputs.
1131+ auto fmas = LLVM::createLLVMIntrinsicCallOp (
1132+ rewriter, loc, " llvm.amdgcn.div.fmas.f32" , elemTy,
1133+ {b.f32_val (0 ), b.f32_val (0 ), approxDiv, vcc})
1134+ .getResult (0 );
1135+
1136+ return {LLVM::createLLVMIntrinsicCallOp (rewriter, loc,
1137+ " llvm.amdgcn.div.fixup.f32" , elemTy,
1138+ {fmas, rhs, lhs})
1139+ .getResult (0 )};
10871140 }
10881141};
10891142
0 commit comments