diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py
index 943c8a8ea6..f61bcefea7 100644
--- a/third_party/intel/backend/compiler.py
+++ b/third_party/intel/backend/compiler.py
@@ -201,6 +201,10 @@ def make_ttir(mod, metadata, opt):
         passes.common.add_symbol_dce(pm)
         passes.ttir.add_loop_unroll(pm)
         pm.run(mod, 'make_ttir')
+
+        if intel.has_precise_divide_sqrt(mod):
+            metadata["build_flags"] = "-cl-fp32-correctly-rounded-divide-sqrt"
+
         return mod
 
     @staticmethod
@@ -363,16 +367,15 @@ def make_llir(src, metadata, options):
     def make_spv(src, metadata, options, device_arch):
         spirv, name = intel.translate_to_spirv(src)
         metadata["name"] = name
+        metadata.setdefault("build_flags", "")
         if options.grf_mode == 'small':
-            metadata["build_flags"] = "-cl-intel-128-GRF-per-thread"
+            metadata["build_flags"] += " -cl-intel-128-GRF-per-thread"
         elif options.grf_mode == 'large':
             if options.num_warps > 32:
                 raise RuntimeError("grf_mode = large cannot be used with num_warps > 32")
-            metadata["build_flags"] = "-cl-intel-256-GRF-per-thread"
+            metadata["build_flags"] += " -cl-intel-256-GRF-per-thread"
         elif options.grf_mode == 'auto':
-            metadata["build_flags"] = "-cl-intel-enable-auto-large-GRF-mode"
-        else:
-            metadata["build_flags"] = ""
+            metadata["build_flags"] += " -cl-intel-enable-auto-large-GRF-mode"
 
         if knobs.intel.disable_igc_opt:
             metadata["build_flags"] += " -cl-opt-disable"
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp
index ce4bf81606..bbaa610d02 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -970,8 +970,8 @@ struct ElementwiseOpConversion
                                    ConversionPatternRewriter &rewriter,
                                    Type elemTy, MultipleOperandsRange operands,
                                    Location loc) const {
-    assert((!getElementType(op.getLhs()).isBF16() &&
-            !getElementType(op.getRhs()).isBF16()) &&
+    assert((!getElementType(operands[0][0]).isBF16() &&
+            !getElementType(operands[0][1]).isBF16()) &&
            "unsupported conversion");
     return {
         rewriter.create<DestOp>(loc, elemTy, operands[0][0], operands[0][1])};
@@ -1146,57 +1146,11 @@ struct PreciseSqrtOpConversion
                                    ConversionPatternRewriter &rewriter,
                                    Type elemTy, MultipleOperandsRange operands,
                                    Location loc) const {
-    auto b = TritonLLVMOpBuilder(loc, rewriter);
-    Value input = operands[0][0];
-    Type origTy = input.getType();
-    if (!origTy.isF64())
-      input = b.fpext(f64_ty, input);
-    Type funcType = LLVM::LLVMFunctionType::get(f64_ty, {f64_ty});
-    LLVM::LLVMFuncOp funcOp =
-        appendOrGetExternFuncOp(rewriter, op, "__imf_sqrt_rn", funcType);
-    funcOp.setCConv(triton::gpu::intel::getDefaultCConv(op));
-    LLVM::CallOp callOp =
-        LLVM::createLLVMCallOp(rewriter, loc, funcOp, {input});
-    callOp.setCConv(funcOp.getCConv());
-    Value result = callOp.getResult();
-    if (!origTy.isF64())
-      result = rewriter.create<LLVM::FPTruncOp>(loc, origTy, result);
-    return {result};
-  }
-};
-
-template <typename TritonOp>
-struct OpToExternCallConversion
-    : public ElementwiseOpConversionBase<TritonOp,
-                                         OpToExternCallConversion<TritonOp>> {
-  using Base =
-      ElementwiseOpConversionBase<TritonOp, OpToExternCallConversion<TritonOp>>;
-  using Base::Base;
-  using Adaptor = typename Base::OpAdaptor;
-
-  explicit OpToExternCallConversion(LLVMTypeConverter &typeConverter,
-                                    ModuleAxisInfoAnalysis &axisAnalysisPass,
-                                    StringRef externFuncName,
-                                    PatternBenefit benefit)
-      : Base::ElementwiseOpConversionBase(typeConverter, axisAnalysisPass,
-                                          benefit),
-        funcName(externFuncName) {}
-
-  SmallVector<Value> createDestOps(TritonOp op, Adaptor adaptor,
-                                   ConversionPatternRewriter &rewriter,
-                                   Type elemTy, MultipleOperandsRange operands,
-                                   Location loc) const {
-    Type funcType = getFunctionType(elemTy, operands[0]);
-    LLVM::LLVMFuncOp funcOp =
-        appendOrGetExternFuncOp(rewriter, op, funcName, funcType);
-    funcOp.setCConv(triton::gpu::intel::getDefaultCConv(op));
-    auto callOp = LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]);
-    callOp.setCConv(funcOp.getCConv());
-    return {callOp.getResult()};
+    // FIXME: Use precise sqrt builtin: #5419
+    // Rely on `-cl-fp32-correctly-rounded-divide-sqrt` for precise sqrt.
+    return {rewriter.create<LLVM::SqrtOp>(loc, elemTy, operands[0],
+                                          adaptor.getAttributes().getValue())};
   }
-
-private:
-  StringRef funcName;
 };
 
 // Following two patterns are copied from the common part to fix-up calling
@@ -1273,8 +1227,10 @@ void populateElementwiseOpToLLVMPatterns(
 
   patterns.add<PreciseSqrtOpConversion>(typeConverter, axisInfoAnalysis,
                                         benefit);
-  patterns.add<OpToExternCallConversion<triton::PreciseDivFOp>>(
-      typeConverter, axisInfoAnalysis, "__imf_fdiv_rn", benefit);
+  // FIXME: Use precise divide builtin: #5419
+  // Rely on `-cl-fp32-correctly-rounded-divide-sqrt` for precise divide.
+  patterns.add<ElementwiseOpConversion<triton::PreciseDivFOp, LLVM::FDivOp>>(
+      typeConverter, axisInfoAnalysis, benefit);
   patterns.add<MulhiUIOpConversion>(typeConverter, axisInfoAnalysis,
                                     targetInfo, benefit);
   patterns.add<ExternElementwiseOpConversion>(typeConverter, axisInfoAnalysis,
diff --git a/third_party/intel/triton_xpu.cc b/third_party/intel/triton_xpu.cc
index 53d6b885a9..6a5084981f 100644
--- a/third_party/intel/triton_xpu.cc
+++ b/third_party/intel/triton_xpu.cc
@@ -299,6 +299,16 @@ void init_triton_intel(py::module &&m) {
     return py::int_(ret);
   });
 
+  m.def("has_precise_divide_sqrt", [](mlir::ModuleOp &mod) -> bool {
+    using namespace mlir;
+    WalkResult result = mod.walk([&](Operation *op) {
+      if (isa<triton::PreciseDivFOp, triton::PreciseSqrtOp>(op))
+        return WalkResult::interrupt();
+      return WalkResult::advance();
+    });
+    return result.wasInterrupted();
+  });
+
   // FIXME: This is for internal experimentation. In the end we will need a
   // producer flag (e.g. PyTorch flag) to allow the Triton compiler to use the
   // fast math semantics on all arithmetic operations.
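
Note for reviewers: the switch in `make_spv` from assignment to append exists because `make_ttir` now runs first and may have already recorded `-cl-fp32-correctly-rounded-divide-sqrt` in `metadata["build_flags"]`; the GRF-mode handling must accumulate flags rather than overwrite them. A minimal sketch of the interaction (plain Python; the driver that threads `metadata` between the two stages is assumed, not shown):

```python
# Sketch of the build_flags accumulation across the two stages in this diff.
metadata = {}

# make_ttir: set when the TTIR module contains precise divide/sqrt ops
# (detected via the new intel.has_precise_divide_sqrt binding).
metadata["build_flags"] = "-cl-fp32-correctly-rounded-divide-sqrt"

# make_spv: seed the key if make_ttir did not set it, then append the GRF flag.
metadata.setdefault("build_flags", "")
metadata["build_flags"] += " -cl-intel-128-GRF-per-thread"  # grf_mode == 'small'

assert "-cl-fp32-correctly-rounded-divide-sqrt" in metadata["build_flags"]
```

The dropped `else: metadata["build_flags"] = ""` branch is what `setdefault` replaces: with plain assignment, the default `grf_mode` path would have erased the flag set in `make_ttir`.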