Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ def forward(q, k, v, causal, sm_scale):
num_warps=num_warps, #
num_stages=num_stages, #
grf_mode='large', #
advanced_path=True, #
)
return o

Expand Down
8 changes: 5 additions & 3 deletions third_party/intel/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class XPUOptions:
backend_name: str = 'intel'
sanitize_overflow: bool = False
generate_native_code: bool = False
advanced_path: bool = False

def __post_init__(self):
default_libdir = Path(__file__).parent / 'lib'
Expand Down Expand Up @@ -232,8 +233,9 @@ def make_ttgir(mod, metadata, opt, properties):
pm = ir.pass_manager(mod.context)
pm.enable_debug()

if (properties["has_subgroup_2d_block_io"] and properties["has_subgroup_matrix_multiply_accumulate"]
and os.getenv("TRITON_INTEL_ADVANCED_PATH", "0") == "1"):
if (os.getenv("TRITON_INTEL_ADVANCED_PATH", "0") == "1" or opt.advanced_path):
if not (properties["has_subgroup_2d_block_io"] and properties["has_subgroup_matrix_multiply_accumulate"]):
raise AssertionError("Target do not support blocked load/mma")
return XPUBackend.AdvancedPath.make_ttgir(mod, metadata, opt)

passes.ttir.add_convert_to_ttgpuir(pm, "xpu", opt.num_warps, opt.threads_per_warp, opt.num_ctas)
Expand Down Expand Up @@ -291,7 +293,7 @@ def make_llir(src, metadata, options):
# being used, e.g., convert_layout.
if os.getenv("TRITON_INTEL_REDUCE_TRANSPOSE", "0") != "1":
intel.passes.ttgpuir.add_allocate_shared_memory(pm)
intel.passes.ttgpuir.add_to_llvmir(pm)
intel.passes.ttgpuir.add_to_llvmir(pm, options.advanced_path)
intel.set_fast_math(mod)
passes.convert.add_arith_to_llvmir(pm)
passes.common.add_canonicalizer(pm)
Expand Down
5 changes: 5 additions & 0 deletions third_party/intel/include/TritonIntelGPUToLLVM/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ def ConvertTritonIntelGPUToLLVM
"mlir::triton::TritonDialect",
"mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonGEN::TritonGENDialect"];
let options = [
Option<"advancedPath", "advanced_path",
"bool", /*default*/"false",
"enable advanced path">,
];
}

#endif // TRITONINTELGPU_CONVERSION_PASSES
10 changes: 2 additions & 8 deletions third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,14 +180,8 @@ struct AddSPIRVEnvPattern : public mlir::OpRewritePattern<ModuleOp> {
/// block pointers or not.
class TritonGPUToLLVMPipelineManager {
public:
TritonGPUToLLVMPipelineManager(ModuleOp &mod, MLIRContext *ctx)
: mod(mod), ctx(ctx),
isAdvancedPathEnabled(
mod->hasAttr(gpu::intel::TritonIntelGPUDialect::
getSupportSG2DBlockAttrName()) &&
mod->hasAttr(
gpu::intel::TritonIntelGPUDialect::getSupportDPASAttrName()) &&
mlir::triton::tools::getBoolEnv("TRITON_INTEL_ADVANCED_PATH")) {}
TritonGPUToLLVMPipelineManager(ModuleOp &mod, MLIRContext *ctx, bool advanced)
: mod(mod), ctx(ctx), isAdvancedPathEnabled(advanced) {}

/// FIXME: remove once the block ptr conversion path is capable of handling
/// shared memory.
Expand Down
21 changes: 14 additions & 7 deletions third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ struct ConvertTritonGPUToLLVM
: public triton::gpu::intel::impl::ConvertTritonIntelGPUToLLVMBase<
ConvertTritonGPUToLLVM> {
using ConvertTritonIntelGPUToLLVMBase::ConvertTritonIntelGPUToLLVMBase;
ConvertTritonGPUToLLVM() = default;
ConvertTritonGPUToLLVM(bool advancedPath) {
this->advancedPath = advancedPath;
}

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<LLVM::LLVMDialect, TritonGEN::TritonGENDialect,
Expand All @@ -78,15 +82,18 @@ struct ConvertTritonGPUToLLVM
MLIRContext *context = &getContext();
ModuleOp mod = getOperation();

bool isAdvancedPathEnabled =
mlir::triton::tools::getBoolEnv("TRITON_INTEL_ADVANCED_PATH") ||
advancedPath;
if (isAdvancedPathEnabled)
assert(mod->hasAttr(triton::gpu::intel::TritonIntelGPUDialect::
getSupportSG2DBlockAttrName()) &&
mod->hasAttr(triton::gpu::intel::TritonIntelGPUDialect::
getSupportDPASAttrName()) &&
"Target do not support blocked load/mma");
mlir::triton::intel::TritonGPUToLLVMPipelineManager pipelineManager(
mod, context);
mod, context, isAdvancedPathEnabled);
mlir::LowerToLLVMOptions option(context);
bool isAdvancedPathEnabled =
mod->hasAttr(triton::gpu::intel::TritonIntelGPUDialect::
getSupportSG2DBlockAttrName()) &&
mod->hasAttr(triton::gpu::intel::TritonIntelGPUDialect::
getSupportDPASAttrName()) &&
mlir::triton::tools::getBoolEnv("TRITON_INTEL_ADVANCED_PATH");
mlir::triton::intel::TargetInfo targetInfo;
TritonIntelGPUToLLVMTypeConverter typeConverter(context, option, targetInfo,
isAdvancedPathEnabled);
Expand Down
4 changes: 2 additions & 2 deletions third_party/intel/triton_xpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ void init_triton_intel_passes_ttir(py::module &&m) {
}

void init_triton_intel_passes_ttgpuir(py::module &&m) {
ADD_PASS_WRAPPER_0("add_to_llvmir",
gpu::intel::createConvertTritonIntelGPUToLLVM);
ADD_PASS_WRAPPER_OPT_1("add_to_llvmir",
gpu::intel::createConvertTritonIntelGPUToLLVM, bool);
ADD_PASS_WRAPPER_0("add_accelerate_matmul",
gpu::intel::createTritonIntelGPUAccelerateMatmul);
ADD_PASS_WRAPPER_0("add_decompose_unsupported_conversions",
Expand Down