diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py
index f6525864b2..f442913275 100644
--- a/third_party/intel/backend/compiler.py
+++ b/third_party/intel/backend/compiler.py
@@ -202,10 +202,6 @@ def make_ttir(mod, metadata, opt):
         passes.common.add_symbol_dce(pm)
         passes.ttir.add_loop_unroll(pm)
         pm.run(mod, 'make_ttir')
-
-        if intel.has_precise_divide_sqrt(mod):
-            metadata["build_flags"] = "-cl-fp32-correctly-rounded-divide-sqrt"
-
         return mod

     @staticmethod
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp
index bbaa610d02..95d4cce003 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -4,6 +4,7 @@
 #include "mlir/IR/MLIRContext.h"
 #include "third_party/intel/include/Dialect/TritonIntelGPU/IR/Utils.h"
 #include "third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"
+#include "third_party/intel/lib/Utils/Mangling.h"
 #include "triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h"
 #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
 #include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
@@ -1135,22 +1136,41 @@ struct AbsFOpConversion
   }
 };

-struct PreciseSqrtOpConversion
-    : ElementwiseOpConversionBase<PreciseSqrtOp, PreciseSqrtOpConversion> {
+template <typename TritonOp>
+struct OpToExternCallConversion
+    : public ElementwiseOpConversionBase<TritonOp,
+                                         OpToExternCallConversion<TritonOp>> {
   using Base =
-      ElementwiseOpConversionBase<PreciseSqrtOp, PreciseSqrtOpConversion>;
+      ElementwiseOpConversionBase<TritonOp, OpToExternCallConversion<TritonOp>>;
   using Base::Base;
   using Adaptor = typename Base::OpAdaptor;

-  SmallVector<Value> createDestOps(PreciseSqrtOp op, Adaptor adaptor,
+  explicit OpToExternCallConversion(LLVMTypeConverter &typeConverter,
+                                    ModuleAxisInfoAnalysis &axisAnalysisPass,
+                                    StringRef externFuncName,
+                                    PatternBenefit benefit)
+      : Base::ElementwiseOpConversionBase(typeConverter, axisAnalysisPass,
+                                          benefit),
+        funcName(externFuncName) {}
+
+  SmallVector<Value> createDestOps(TritonOp op, Adaptor adaptor,
                                    ConversionPatternRewriter &rewriter,
                                    Type elemTy, MultipleOperandsRange operands,
                                    Location loc) const {
-    // FIXME: Use precise sqrt builtin: #5419
-    // Rely on `-cl-fp32-correctly-rounded-divide-sqrt` for precise sqrt.
-    return {rewriter.create<math::SqrtOp>(loc, elemTy, operands[0],
-                                          adaptor.getAttributes().getValue())};
+    Type funcType = getFunctionType(elemTy, operands[0]);
+    SmallVector<Type> operandTypes(ValueRange(operands[0]).getTypes());
+    std::string fnName =
+        mlir::triton::gpu::intel::mangle(funcName, operandTypes);
+    LLVM::LLVMFuncOp funcOp =
+        appendOrGetExternFuncOp(rewriter, op, fnName, funcType);
+    funcOp.setCConv(triton::gpu::intel::getDefaultCConv(op));
+    auto callOp = LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]);
+    callOp.setCConv(funcOp.getCConv());
+    return {callOp.getResult()};
   }
+
+private:
+  StringRef funcName;
 };

 // Following two patterns are copied from the common part to fix-up calling
@@ -1225,12 +1245,10 @@ void populateElementwiseOpToLLVMPatterns(
     ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo,
     PatternBenefit benefit) {
-  patterns.add<PreciseSqrtOpConversion>(typeConverter, axisInfoAnalysis,
-                                        benefit);
-  // FIXME: Use precise divide builtin: #5419
-  // Rely on `-cl-fp32-correctly-rounded-divide-sqrt` for precise divide.
-  patterns.add<ElementwiseOpConversion<PreciseDivFOp, LLVM::FDivOp>>(
-      typeConverter, axisInfoAnalysis, benefit);
+  patterns.add<OpToExternCallConversion<PreciseSqrtOp>>(
+      typeConverter, axisInfoAnalysis, "sqrt_cr", benefit);
+  patterns.add<OpToExternCallConversion<PreciseDivFOp>>(
+      typeConverter, axisInfoAnalysis, "divide_cr", benefit);
   patterns.add(typeConverter, axisInfoAnalysis, targetInfo,
                benefit);
   patterns.add(typeConverter, axisInfoAnalysis,
diff --git a/third_party/intel/triton_xpu.cc b/third_party/intel/triton_xpu.cc
index 9316368a6a..628736a9ed 100644
--- a/third_party/intel/triton_xpu.cc
+++ b/third_party/intel/triton_xpu.cc
@@ -301,16 +301,6 @@ void init_triton_intel(py::module &&m) {
     return py::int_(ret);
   });

-  m.def("has_precise_divide_sqrt", [](mlir::ModuleOp &mod) -> bool {
-    using namespace mlir;
-    WalkResult result = mod.walk([&](Operation *op) {
-      if (isa<triton::PreciseSqrtOp, triton::PreciseDivFOp>(op))
-        return WalkResult::interrupt();
-      return WalkResult::advance();
-    });
-    return result.wasInterrupted();
-  });
-
   // FIXME: This is for internal experimentation. In the end we will need a
   // producer flag (e.g. PyTorch flag) to allow the Triton compiler to use the
   // fast math semantics on all arithmetic operations.
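For reference, the precise ops retargeted by this patch are the ones produced by `tl.sqrt_rn` and `tl.fdiv(..., ieee_rounding=True)` in the Triton language. The sketch below is illustrative only and is not part of the patch; it assumes a PyTorch build exposing an `xpu` device and shows a kernel whose `tt.precise_sqrt` / `tt.precise_divf` ops previously relied on the `-cl-fp32-correctly-rounded-divide-sqrt` build flag and are now lowered to mangled `sqrt_cr` / `divide_cr` extern calls.

import torch
import triton
import triton.language as tl


@triton.jit
def precise_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    # sqrt_rn and fdiv(ieee_rounding=True) emit tt.precise_sqrt /
    # tt.precise_divf, which the Intel backend now converts to extern
    # calls via OpToExternCallConversion.
    z = tl.fdiv(tl.sqrt_rn(x), y, ieee_rounding=True)
    tl.store(out_ptr + offs, z, mask=mask)


# Hypothetical driver code, assuming an Intel GPU visible as the "xpu" device.
x = torch.rand(1024, device="xpu") + 1.0
y = torch.rand(1024, device="xpu") + 1.0
out = torch.empty_like(x)
precise_kernel[(triton.cdiv(1024, 128), )](x, y, out, 1024, BLOCK=128)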