@@ -1408,62 +1408,9 @@ struct FDivOpConversion
1408
1408
ConversionPatternRewriter &rewriter,
1409
1409
Type elemTy, MultipleOperandsRange operands,
1410
1410
Location loc) const {
1411
- // For non-F32 input, it's lowered to LLVM::FDivOp, which is a
1412
- // IEEE-compliant DIV operation.
1413
- if (elemTy.getIntOrFloatBitWidth () != 32 )
1414
- return {rewriter.create <LLVM::FDivOp>(loc, elemTy, operands[0 ][0 ],
1415
- operands[0 ][1 ])};
1416
-
1417
- auto b = TritonLLVMOpBuilder (loc, rewriter);
1418
1411
1419
- // The algorithm comes from
1420
- // https://github.com/llvm/llvm-project/blob/bda7aadf/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp#L4980-L5065
1421
- // with the Newton-Raphson refinement removed, to perform a faster,
1422
- // approximated DIV operation, aligning with the `div.full.f32` instruction
1423
- // on the NV backend.
1424
- Value &lhs = operands[0 ][0 ];
1425
- Value &rhs = operands[0 ][1 ];
1426
- MLIRContext *ctx = rewriter.getContext ();
1427
- Type divScaleResType = struct_ty ({elemTy, i1_ty});
1428
-
1429
- // The `llvm.amdgcn.div.scale.f32` instruction's signature is
1430
- // (src0, src1, src2) -> (ret0, ret1), where
1431
- //
1432
- // src0: The numerator or lhs of FDivOp.
1433
- // src1: The denominator or rhs of FDivOp.
1434
- // src2: A boolean indicating which operand to scale. If true, lhs is
1435
- // scaled; Otherwise, rhs is scaled.
1436
- //
1437
- // ret0: The scaled operand.
1438
- // ret1: The VCC register indicating whether post-scaling is required.
1439
- auto denominatorScaleOp = LLVM::createLLVMIntrinsicCallOp (
1440
- rewriter, loc, " llvm.amdgcn.div.scale.f32" , divScaleResType,
1441
- {lhs, rhs, b.false_val ()});
1442
- Value denominatorScaled = b.extract_val (denominatorScaleOp.getResult (0 ), 0 );
1443
- auto numeratorScaleOp = LLVM::createLLVMIntrinsicCallOp (
1444
- rewriter, loc, " llvm.amdgcn.div.scale.f32" , divScaleResType,
1445
- {lhs, rhs, b.true_val ()});
1446
- Value numeratorScaled = b.extract_val (numeratorScaleOp.getResult (0 ), 0 );
1447
- Value vcc = b.extract_val (numeratorScaleOp.getResult (0 ), 1 );
1448
-
1449
- Value rcp =
1450
- LLVM::createLLVMIntrinsicCallOp (rewriter, loc, " llvm.amdgcn.rcp.f32" ,
1451
- elemTy, {denominatorScaled})
1452
- .getResult (0 );
1453
-
1454
- Value approxDiv = b.fmul (numeratorScaled, rcp);
1455
-
1456
- // Since the Newton-Raphson is skipped, we use 0 instead of approximations
1457
- // as the inputs.
1458
- auto fmas = LLVM::createLLVMIntrinsicCallOp (
1459
- rewriter, loc, " llvm.amdgcn.div.fmas.f32" , elemTy,
1460
- {b.f32_val (0 ), b.f32_val (0 ), approxDiv, vcc})
1461
- .getResult (0 );
1462
-
1463
- return {LLVM::createLLVMIntrinsicCallOp (rewriter, loc,
1464
- " llvm.amdgcn.div.fixup.f32" , elemTy,
1465
- {fmas, rhs, lhs})
1466
- .getResult (0 )};
1412
+ return {rewriter.create <LLVM::FDivOp>(loc, elemTy, operands[0 ][0 ],
1413
+ operands[0 ][1 ])};
1467
1414
}
1468
1415
};
1469
1416
0 commit comments