Skip to content

Commit 04fcbe5

Browse files
Use OCL builtin for precise divide and sqrt (#5432)
Note that these builtins (`sqrt_cr` and `divide_cr`) are not official interfaces. Fixes #5419. Signed-off-by: Whitney Tsang <[email protected]>
1 parent a9362d2 commit 04fcbe5

File tree

3 files changed

+32
-28
lines changed

3 files changed

+32
-28
lines changed

third_party/intel/backend/compiler.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -202,10 +202,6 @@ def make_ttir(mod, metadata, opt):
202202
passes.common.add_symbol_dce(pm)
203203
passes.ttir.add_loop_unroll(pm)
204204
pm.run(mod, 'make_ttir')
205-
206-
if intel.has_precise_divide_sqrt(mod):
207-
metadata["build_flags"] = "-cl-fp32-correctly-rounded-divide-sqrt"
208-
209205
return mod
210206

211207
@staticmethod

third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "mlir/IR/MLIRContext.h"
55
#include "third_party/intel/include/Dialect/TritonIntelGPU/IR/Utils.h"
66
#include "third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"
7+
#include "third_party/intel/lib/Utils/Mangling.h"
78
#include "triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h"
89
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
910
#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
@@ -1135,22 +1136,41 @@ struct AbsFOpConversion
11351136
}
11361137
};
11371138

1138-
struct PreciseSqrtOpConversion
1139-
: ElementwiseOpConversionBase<PreciseSqrtOp, PreciseSqrtOpConversion> {
1139+
template <typename TritonOp>
1140+
struct OpToExternCallConversion
1141+
: public ElementwiseOpConversionBase<TritonOp,
1142+
OpToExternCallConversion<TritonOp>> {
11401143
using Base =
1141-
ElementwiseOpConversionBase<PreciseSqrtOp, PreciseSqrtOpConversion>;
1144+
ElementwiseOpConversionBase<TritonOp, OpToExternCallConversion<TritonOp>>;
11421145
using Base::Base;
11431146
using Adaptor = typename Base::OpAdaptor;
11441147

1145-
SmallVector<Value> createDestOps(PreciseSqrtOp op, Adaptor adaptor,
1148+
explicit OpToExternCallConversion(LLVMTypeConverter &typeConverter,
1149+
ModuleAxisInfoAnalysis &axisAnalysisPass,
1150+
StringRef externFuncName,
1151+
PatternBenefit benefit)
1152+
: Base::ElementwiseOpConversionBase(typeConverter, axisAnalysisPass,
1153+
benefit),
1154+
funcName(externFuncName) {}
1155+
1156+
SmallVector<Value> createDestOps(TritonOp op, Adaptor adaptor,
11461157
ConversionPatternRewriter &rewriter,
11471158
Type elemTy, MultipleOperandsRange operands,
11481159
Location loc) const {
1149-
// FIXME: Use precise sqrt builtin: #5419
1150-
// Rely on `-cl-fp32-correctly-rounded-divide-sqrt` for precise sqrt.
1151-
return {rewriter.create<LLVM::SqrtOp>(loc, elemTy, operands[0],
1152-
adaptor.getAttributes().getValue())};
1160+
Type funcType = getFunctionType(elemTy, operands[0]);
1161+
SmallVector<Type> operandTypes(ValueRange(operands[0]).getTypes());
1162+
std::string fnName =
1163+
mlir::triton::gpu::intel::mangle(funcName, operandTypes);
1164+
LLVM::LLVMFuncOp funcOp =
1165+
appendOrGetExternFuncOp(rewriter, op, fnName, funcType);
1166+
funcOp.setCConv(triton::gpu::intel::getDefaultCConv(op));
1167+
auto callOp = LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]);
1168+
callOp.setCConv(funcOp.getCConv());
1169+
return {callOp.getResult()};
11531170
}
1171+
1172+
private:
1173+
StringRef funcName;
11541174
};
11551175

11561176
// Following two patterns are copied from the common part to fix-up calling
@@ -1225,12 +1245,10 @@ void populateElementwiseOpToLLVMPatterns(
12251245
ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo,
12261246
PatternBenefit benefit) {
12271247

1228-
patterns.add<PreciseSqrtOpConversion>(typeConverter, axisInfoAnalysis,
1229-
benefit);
1230-
// FIXME: Use precise divide builtin: #5419
1231-
// Rely on `-cl-fp32-correctly-rounded-divide-sqrt` for precise divide.
1232-
patterns.add<ElementwiseOpConversion<triton::PreciseDivFOp, LLVM::FDivOp>>(
1233-
typeConverter, axisInfoAnalysis, benefit);
1248+
patterns.add<OpToExternCallConversion<triton::PreciseSqrtOp>>(
1249+
typeConverter, axisInfoAnalysis, "sqrt_cr", benefit);
1250+
patterns.add<OpToExternCallConversion<triton::PreciseDivFOp>>(
1251+
typeConverter, axisInfoAnalysis, "divide_cr", benefit);
12341252
patterns.add<MulhiUIOpConversion>(typeConverter, axisInfoAnalysis, targetInfo,
12351253
benefit);
12361254
patterns.add<ExternElementwiseOpConversion>(typeConverter, axisInfoAnalysis,

third_party/intel/triton_xpu.cc

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -301,16 +301,6 @@ void init_triton_intel(py::module &&m) {
301301
return py::int_(ret);
302302
});
303303

304-
m.def("has_precise_divide_sqrt", [](mlir::ModuleOp &mod) -> bool {
305-
using namespace mlir;
306-
WalkResult result = mod.walk([&](Operation *op) {
307-
if (isa<mlir::triton::PreciseDivFOp, mlir::triton::PreciseSqrtOp>(op))
308-
return WalkResult::interrupt();
309-
return WalkResult::advance();
310-
});
311-
return result.wasInterrupted();
312-
});
313-
314304
// FIXME: This is for internal experimentation. In the end we will need a
315305
// producer flag (e.g. PyTorch flag) to allow the Triton compiler to use the
316306
// fast math semantics on all arithmetic operations.

0 commit comments

Comments
 (0)