Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions third_party/intel/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,10 @@ def make_ttir(mod, metadata, opt):
passes.common.add_symbol_dce(pm)
passes.ttir.add_loop_unroll(pm)
pm.run(mod, 'make_ttir')

if intel.has_precise_divide_sqrt(mod):
metadata["build_flags"] = "-cl-fp32-correctly-rounded-divide-sqrt"

return mod

@staticmethod
Expand Down Expand Up @@ -363,16 +367,15 @@ def make_llir(src, metadata, options):
def make_spv(src, metadata, options, device_arch):
spirv, name = intel.translate_to_spirv(src)
metadata["name"] = name
metadata.setdefault("build_flags", "")
if options.grf_mode == 'small':
metadata["build_flags"] = "-cl-intel-128-GRF-per-thread"
metadata["build_flags"] += " -cl-intel-128-GRF-per-thread"
elif options.grf_mode == 'large':
if options.num_warps > 32:
raise RuntimeError("grf_mode = large cannot be used with num_warps > 32")
metadata["build_flags"] = "-cl-intel-256-GRF-per-thread"
metadata["build_flags"] += " -cl-intel-256-GRF-per-thread"
elif options.grf_mode == 'auto':
metadata["build_flags"] = "-cl-intel-enable-auto-large-GRF-mode"
else:
metadata["build_flags"] = ""
metadata["build_flags"] += " -cl-intel-enable-auto-large-GRF-mode"

if knobs.intel.disable_igc_opt:
metadata["build_flags"] += " -cl-opt-disable"
Expand Down
64 changes: 10 additions & 54 deletions third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -970,8 +970,8 @@ struct ElementwiseOpConversion
ConversionPatternRewriter &rewriter,
Type elemTy, MultipleOperandsRange operands,
Location loc) const {
assert((!getElementType(op.getLhs()).isBF16() &&
!getElementType(op.getRhs()).isBF16()) &&
assert((!getElementType(operands[0][0]).isBF16() &&
!getElementType(operands[0][1]).isBF16()) &&
"unsupported conversion");
return {
rewriter.create<DestOp>(loc, elemTy, operands[0][0], operands[0][1])};
Expand Down Expand Up @@ -1146,57 +1146,11 @@ struct PreciseSqrtOpConversion
ConversionPatternRewriter &rewriter,
Type elemTy, MultipleOperandsRange operands,
Location loc) const {
auto b = TritonLLVMOpBuilder(loc, rewriter);
Value input = operands[0][0];
Type origTy = input.getType();
if (!origTy.isF64())
input = b.fpext(f64_ty, input);
Type funcType = LLVM::LLVMFunctionType::get(f64_ty, {f64_ty});
LLVM::LLVMFuncOp funcOp =
appendOrGetExternFuncOp(rewriter, op, "__imf_sqrt_rn", funcType);
funcOp.setCConv(triton::gpu::intel::getDefaultCConv(op));
LLVM::CallOp callOp =
LLVM::createLLVMCallOp(rewriter, loc, funcOp, {input});
callOp.setCConv(funcOp.getCConv());
Value result = callOp.getResult();
if (!origTy.isF64())
result = rewriter.create<LLVM::FPTruncOp>(loc, origTy, result);
return {result};
}
};

template <typename TritonOp>
struct OpToExternCallConversion
: public ElementwiseOpConversionBase<TritonOp,
OpToExternCallConversion<TritonOp>> {
using Base =
ElementwiseOpConversionBase<TritonOp, OpToExternCallConversion<TritonOp>>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;

explicit OpToExternCallConversion(LLVMTypeConverter &typeConverter,
ModuleAxisInfoAnalysis &axisAnalysisPass,
StringRef externFuncName,
PatternBenefit benefit)
: Base::ElementwiseOpConversionBase(typeConverter, axisAnalysisPass,
benefit),
funcName(externFuncName) {}

SmallVector<Value> createDestOps(TritonOp op, Adaptor adaptor,
ConversionPatternRewriter &rewriter,
Type elemTy, MultipleOperandsRange operands,
Location loc) const {
Type funcType = getFunctionType(elemTy, operands[0]);
LLVM::LLVMFuncOp funcOp =
appendOrGetExternFuncOp(rewriter, op, funcName, funcType);
funcOp.setCConv(triton::gpu::intel::getDefaultCConv(op));
auto callOp = LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]);
callOp.setCConv(funcOp.getCConv());
return {callOp.getResult()};
// FIXME: Use precise sqrt builtin: #5419
// Rely on `-cl-fp32-correctly-rounded-divide-sqrt` for precise sqrt.
return {rewriter.create<LLVM::SqrtOp>(loc, elemTy, operands[0],
adaptor.getAttributes().getValue())};
}

private:
StringRef funcName;
};

// Following two patterns are copied from the common part to fix-up calling
Expand Down Expand Up @@ -1273,8 +1227,10 @@ void populateElementwiseOpToLLVMPatterns(

patterns.add<PreciseSqrtOpConversion>(typeConverter, axisInfoAnalysis,
benefit);
patterns.add<OpToExternCallConversion<triton::PreciseDivFOp>>(
typeConverter, axisInfoAnalysis, "__imf_fdiv_rn", benefit);
// FIXME: Use precise divide builtin: #5419
// Rely on `-cl-fp32-correctly-rounded-divide-sqrt` for precise divide.
patterns.add<ElementwiseOpConversion<triton::PreciseDivFOp, LLVM::FDivOp>>(
typeConverter, axisInfoAnalysis, benefit);
patterns.add<MulhiUIOpConversion>(typeConverter, axisInfoAnalysis, targetInfo,
benefit);
patterns.add<ExternElementwiseOpConversion>(typeConverter, axisInfoAnalysis,
Expand Down
10 changes: 10 additions & 0 deletions third_party/intel/triton_xpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,16 @@ void init_triton_intel(py::module &&m) {
return py::int_(ret);
});

// Report whether the module contains any precise-division or precise-sqrt
// Triton ops. The Python side uses this to decide whether extra build flags
// (e.g. correctly-rounded divide/sqrt) must be passed to the OpenCL compiler.
m.def("has_precise_divide_sqrt", [](mlir::ModuleOp &mod) -> bool {
  using namespace mlir;
  // Walk the whole module; stop early as soon as one matching op is seen.
  const WalkResult walkStatus = mod.walk([](Operation *op) {
    const bool isPreciseOp =
        isa<mlir::triton::PreciseDivFOp, mlir::triton::PreciseSqrtOp>(op);
    return isPreciseOp ? WalkResult::interrupt() : WalkResult::advance();
  });
  // An interrupted walk means a precise divide/sqrt op was found.
  return walkStatus.wasInterrupted();
});

// FIXME: This is for internal experimentation. In the end we will need a
// producer flag (e.g. PyTorch flag) to allow the Triton compiler to use the
// fast math semantics on all arithmetic operations.
Expand Down
Loading