Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions third_party/intel/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,10 +202,6 @@ def make_ttir(mod, metadata, opt):
passes.common.add_symbol_dce(pm)
passes.ttir.add_loop_unroll(pm)
pm.run(mod, 'make_ttir')

if intel.has_precise_divide_sqrt(mod):
metadata["build_flags"] = "-cl-fp32-correctly-rounded-divide-sqrt"

return mod

@staticmethod
Expand Down
46 changes: 32 additions & 14 deletions third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "mlir/IR/MLIRContext.h"
#include "third_party/intel/include/Dialect/TritonIntelGPU/IR/Utils.h"
#include "third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"
#include "third_party/intel/lib/Utils/Mangling.h"
#include "triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h"
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
Expand Down Expand Up @@ -1135,22 +1136,41 @@ struct AbsFOpConversion
}
};

// NOTE(review): this span is a unified diff shown without +/- markers, so
// lines of two definitions are interleaved: `PreciseSqrtOpConversion` (which
// lowers to `LLVM::SqrtOp`) and the templated `OpToExternCallConversion`
// (which lowers to a call to an external function). Presumably the former is
// the removed side and the latter the added side -- confirm against the real
// file before editing.
//
// OpToExternCallConversion<TritonOp>: elementwise conversion pattern that
// replaces each scalar `TritonOp` with an LLVM call to an externally declared
// function whose base name is supplied at construction time.
struct PreciseSqrtOpConversion
: ElementwiseOpConversionBase<PreciseSqrtOp, PreciseSqrtOpConversion> {
template <typename TritonOp>
struct OpToExternCallConversion
: public ElementwiseOpConversionBase<TritonOp,
OpToExternCallConversion<TritonOp>> {
using Base =
ElementwiseOpConversionBase<PreciseSqrtOp, PreciseSqrtOpConversion>;
ElementwiseOpConversionBase<TritonOp, OpToExternCallConversion<TritonOp>>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;

SmallVector<Value> createDestOps(PreciseSqrtOp op, Adaptor adaptor,
// Constructor: forwards the type converter, axis analysis and benefit to the
// base pattern and records `externFuncName` as the unmangled callee name.
explicit OpToExternCallConversion(LLVMTypeConverter &typeConverter,
ModuleAxisInfoAnalysis &axisAnalysisPass,
StringRef externFuncName,
PatternBenefit benefit)
: Base::ElementwiseOpConversionBase(typeConverter, axisAnalysisPass,
benefit),
funcName(externFuncName) {}

// Emits the replacement for one scalar element of `op`.
// `operands[0]` is used both as the source of the callee's parameter types
// and as the call's argument list; the single call result is returned.
SmallVector<Value> createDestOps(TritonOp op, Adaptor adaptor,
ConversionPatternRewriter &rewriter,
Type elemTy, MultipleOperandsRange operands,
Location loc) const {
// FIXME: Use precise sqrt builtin: #5419
// Rely on `-cl-fp32-correctly-rounded-divide-sqrt` for precise sqrt.
return {rewriter.create<LLVM::SqrtOp>(loc, elemTy, operands[0],
adaptor.getAttributes().getValue())};
// Function type is built from the result element type and the operands.
Type funcType = getFunctionType(elemTy, operands[0]);
SmallVector<Type> operandTypes(ValueRange(operands[0]).getTypes());
// The callee symbol is `funcName` mangled with the operand types.
std::string fnName =
mlir::triton::gpu::intel::mangle(funcName, operandTypes);
// Presumably declares the extern function on first use and reuses the
// existing declaration afterwards -- helper is not visible here, confirm.
LLVM::LLVMFuncOp funcOp =
appendOrGetExternFuncOp(rewriter, op, fnName, funcType);
// Callee and call site are given the same calling convention, taken from
// getDefaultCConv(op).
funcOp.setCConv(triton::gpu::intel::getDefaultCConv(op));
auto callOp = LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]);
callOp.setCConv(funcOp.getCConv());
return {callOp.getResult()};
}

private:
// Unmangled name of the external function to call.
// NOTE(review): StringRef does not own its characters, so the string passed
// to the constructor must outlive this pattern. The registrations visible in
// this diff pass string literals ("sqrt_cr", "divide_cr"); verify any other
// caller, or consider storing a std::string instead.
StringRef funcName;
};

// Following two patterns are copied from the common part to fix-up calling
Expand Down Expand Up @@ -1225,12 +1245,10 @@ void populateElementwiseOpToLLVMPatterns(
ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo,
PatternBenefit benefit) {

patterns.add<PreciseSqrtOpConversion>(typeConverter, axisInfoAnalysis,
benefit);
// FIXME: Use precise divide builtin: #5419
// Rely on `-cl-fp32-correctly-rounded-divide-sqrt` for precise divide.
patterns.add<ElementwiseOpConversion<triton::PreciseDivFOp, LLVM::FDivOp>>(
typeConverter, axisInfoAnalysis, benefit);
patterns.add<OpToExternCallConversion<triton::PreciseSqrtOp>>(
typeConverter, axisInfoAnalysis, "sqrt_cr", benefit);
patterns.add<OpToExternCallConversion<triton::PreciseDivFOp>>(
typeConverter, axisInfoAnalysis, "divide_cr", benefit);
patterns.add<MulhiUIOpConversion>(typeConverter, axisInfoAnalysis, targetInfo,
benefit);
patterns.add<ExternElementwiseOpConversion>(typeConverter, axisInfoAnalysis,
Expand Down
10 changes: 0 additions & 10 deletions third_party/intel/triton_xpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -301,16 +301,6 @@ void init_triton_intel(py::module &&m) {
return py::int_(ret);
});

// Python binding `has_precise_divide_sqrt(mod) -> bool`.
// Walks every operation in `mod` and reports whether at least one of them is
// a `triton::PreciseDivFOp` or a `triton::PreciseSqrtOp`.
m.def("has_precise_divide_sqrt", [](mlir::ModuleOp &mod) -> bool {
using namespace mlir;
WalkResult result = mod.walk([&](Operation *op) {
// Interrupt on the first match; the rest of the module need not be visited.
if (isa<mlir::triton::PreciseDivFOp, mlir::triton::PreciseSqrtOp>(op))
return WalkResult::interrupt();
return WalkResult::advance();
});
// The walk was interrupted only if a matching op was found.
return result.wasInterrupted();
});

// FIXME: This is for internal experimentation. In the end we will need a
// producer flag (e.g. PyTorch flag) to allow the Triton compiler to use the
// fast math semantics on all arithmetic operations.
Expand Down
Loading