From c4274992473fc67306c2212be806cd57d563e659 Mon Sep 17 00:00:00 2001 From: Dewei Wang Date: Tue, 12 Nov 2024 22:16:27 -0800 Subject: [PATCH 1/5] add xpu option to enable advanced path --- .../flash_attention_fwd_benchmark.py | 1 + third_party/intel/backend/compiler.py | 5 +++-- .../intel/include/TritonIntelGPUToLLVM/Passes.td | 5 +++++ .../lib/TritonIntelGPUToLLVM/PipelineManager.h | 10 ++-------- .../lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp | 13 +++++++++---- third_party/intel/triton_xpu.cc | 4 ++-- 6 files changed, 22 insertions(+), 16 deletions(-) diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py index a31290d850..73668bfb2e 100644 --- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py @@ -214,6 +214,7 @@ def forward(q, k, v, causal, sm_scale): num_warps=num_warps, # num_stages=num_stages, # grf_mode='large', # + advanced_path=True, ) return o diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py index 50301edd1d..a9e6e78c83 100644 --- a/third_party/intel/backend/compiler.py +++ b/third_party/intel/backend/compiler.py @@ -56,6 +56,7 @@ class XPUOptions: backend_name: str = 'intel' sanitize_overflow: bool = False generate_native_code: bool = False + advanced_path: bool = False def __post_init__(self): default_libdir = Path(__file__).parent / 'lib' @@ -233,7 +234,7 @@ def make_ttgir(mod, metadata, opt, properties): pm.enable_debug() if (properties["has_subgroup_2d_block_io"] and properties["has_subgroup_matrix_multiply_accumulate"] - and os.getenv("TRITON_INTEL_ADVANCED_PATH", "0") == "1"): + and (os.getenv("TRITON_INTEL_ADVANCED_PATH", "0") == "1" or opt.advanced_path)): return XPUBackend.AdvancedPath.make_ttgir(mod, metadata, opt) passes.ttir.add_convert_to_ttgpuir(pm, "xpu", opt.num_warps, opt.threads_per_warp, opt.num_ctas) @@ -291,7 +292,7 @@ def make_llir(src, metadata, options): # being used, e.g., convert_layout. if os.getenv("TRITON_INTEL_REDUCE_TRANSPOSE", "0") != "1": intel.passes.ttgpuir.add_allocate_shared_memory(pm) - intel.passes.ttgpuir.add_to_llvmir(pm) + intel.passes.ttgpuir.add_to_llvmir(pm, options.advanced_path) intel.set_fast_math(mod) passes.convert.add_arith_to_llvmir(pm) passes.common.add_canonicalizer(pm) diff --git a/third_party/intel/include/TritonIntelGPUToLLVM/Passes.td b/third_party/intel/include/TritonIntelGPUToLLVM/Passes.td index 16c52b703d..95f55296ed 100644 --- a/third_party/intel/include/TritonIntelGPUToLLVM/Passes.td +++ b/third_party/intel/include/TritonIntelGPUToLLVM/Passes.td @@ -27,6 +27,11 @@ def ConvertTritonIntelGPUToLLVM "mlir::triton::TritonDialect", "mlir::triton::gpu::TritonGPUDialect", "mlir::triton::TritonGEN::TritonGENDialect"]; + let options = [ + Option<"advancedPath", "advanced_path", + "bool", /*default*/"false", + "enable advanced path">, + ]; } #endif // TRITONINTELGPU_CONVERSION_PASSES diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h b/third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h index b52b3a3b97..dc60c62656 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h @@ -180,14 +180,8 @@ struct AddSPIRVEnvPattern : public mlir::OpRewritePattern { /// block pointers or not. class TritonGPUToLLVMPipelineManager { public: - TritonGPUToLLVMPipelineManager(ModuleOp &mod, MLIRContext *ctx) - : mod(mod), ctx(ctx), - isAdvancedPathEnabled( - mod->hasAttr(gpu::intel::TritonIntelGPUDialect:: - getSupportSG2DBlockAttrName()) && - mod->hasAttr( - gpu::intel::TritonIntelGPUDialect::getSupportDPASAttrName()) && - mlir::triton::tools::getBoolEnv("TRITON_INTEL_ADVANCED_PATH")) {} + TritonGPUToLLVMPipelineManager(ModuleOp &mod, MLIRContext *ctx, bool advanced) + : mod(mod), ctx(ctx), isAdvancedPathEnabled(advanced) {} /// FIXME: remove once the block ptr conversion path is capable of handling /// shared memory. diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp index a4c2da184e..cce4829e0e 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp @@ -68,6 +68,10 @@ struct ConvertTritonGPUToLLVM : public triton::gpu::intel::impl::ConvertTritonIntelGPUToLLVMBase< ConvertTritonGPUToLLVM> { using ConvertTritonIntelGPUToLLVMBase::ConvertTritonIntelGPUToLLVMBase; + ConvertTritonGPUToLLVM() = default; + ConvertTritonGPUToLLVM(bool advancedPath) { + this->advancedPath = advancedPath; + } void getDependentDialects(DialectRegistry ®istry) const override { registry.inserthasAttr(triton::gpu::intel::TritonIntelGPUDialect:: getSupportSG2DBlockAttrName()) && mod->hasAttr(triton::gpu::intel::TritonIntelGPUDialect:: getSupportDPASAttrName()) && - mlir::triton::tools::getBoolEnv("TRITON_INTEL_ADVANCED_PATH"); + (mlir::triton::tools::getBoolEnv("TRITON_INTEL_ADVANCED_PATH") || + advancedPath); + mlir::triton::intel::TritonGPUToLLVMPipelineManager pipelineManager( + mod, context, isAdvancedPathEnabled); + mlir::LowerToLLVMOptions option(context); mlir::triton::intel::TargetInfo targetInfo; TritonIntelGPUToLLVMTypeConverter typeConverter(context, option, targetInfo, isAdvancedPathEnabled); diff --git a/third_party/intel/triton_xpu.cc b/third_party/intel/triton_xpu.cc index 55db149919..4134d70221 100644 --- a/third_party/intel/triton_xpu.cc +++ b/third_party/intel/triton_xpu.cc @@ -67,8 +67,8 @@ void init_triton_intel_passes_ttir(py::module &&m) { } void init_triton_intel_passes_ttgpuir(py::module &&m) { - ADD_PASS_WRAPPER_0("add_to_llvmir", - gpu::intel::createConvertTritonIntelGPUToLLVM); + ADD_PASS_WRAPPER_OPT_1("add_to_llvmir", + gpu::intel::createConvertTritonIntelGPUToLLVM, bool); ADD_PASS_WRAPPER_0("add_accelerate_matmul", gpu::intel::createTritonIntelGPUAccelerateMatmul); ADD_PASS_WRAPPER_0("add_decompose_unsupported_conversions", From c3d17437849351ac77cfff6ad7d196005ebb863b Mon Sep 17 00:00:00 2001 From: Dewei Wang Date: Sun, 17 Nov 2024 18:39:29 -0800 Subject: [PATCH 2/5] fix format --- .../triton_kernels_benchmark/flash_attention_fwd_benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py index 73668bfb2e..5e2a47b2bc 100644 --- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py @@ -214,7 +214,7 @@ def forward(q, k, v, causal, sm_scale): num_warps=num_warps, # num_stages=num_stages, # grf_mode='large', # - advanced_path=True, + advanced_path=True, # ) return o From e8996c5903e7c7b25d2718c05a95df1208b67a05 Mon Sep 17 00:00:00 2001 From: Dewei Wang Date: Wed, 20 Nov 2024 17:46:22 -0800 Subject: [PATCH 3/5] fix review comments --- third_party/intel/backend/compiler.py | 5 +++-- .../lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp | 14 ++++++++------ 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py index a9e6e78c83..afb0aafe1a 100644 --- a/third_party/intel/backend/compiler.py +++ b/third_party/intel/backend/compiler.py @@ -233,8 +233,9 @@ def make_ttgir(mod, metadata, opt, properties): pm = ir.pass_manager(mod.context) pm.enable_debug() - if (properties["has_subgroup_2d_block_io"] and properties["has_subgroup_matrix_multiply_accumulate"] - and (os.getenv("TRITON_INTEL_ADVANCED_PATH", "0") == "1" or opt.advanced_path)): + if (os.getenv("TRITON_INTEL_ADVANCED_PATH", "0") == "1" or opt.advanced_path): + assert properties["has_subgroup_2d_block_io"] and properties["has_subgroup_matrix_multiply_accumulate"], \ + "Target do not support blocked load/mma" return XPUBackend.AdvancedPath.make_ttgir(mod, metadata, opt) passes.ttir.add_convert_to_ttgpuir(pm, "xpu", opt.num_warps, opt.threads_per_warp, opt.num_ctas) diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp index cce4829e0e..0cd3d582f8 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp @@ -83,12 +83,14 @@ struct ConvertTritonGPUToLLVM ModuleOp mod = getOperation(); bool isAdvancedPathEnabled = - mod->hasAttr(triton::gpu::intel::TritonIntelGPUDialect:: - getSupportSG2DBlockAttrName()) && - mod->hasAttr(triton::gpu::intel::TritonIntelGPUDialect:: - getSupportDPASAttrName()) && - (mlir::triton::tools::getBoolEnv("TRITON_INTEL_ADVANCED_PATH") || - advancedPath); + mlir::triton::tools::getBoolEnv("TRITON_INTEL_ADVANCED_PATH") || + advancedPath; + if (isAdvancedPathEnabled) + assert(mod->hasAttr(triton::gpu::intel::TritonIntelGPUDialect:: + getSupportSG2DBlockAttrName()) && + mod->hasAttr(triton::gpu::intel::TritonIntelGPUDialect:: + getSupportDPASAttrName()) && + "Target do not support blocked load/mma"); mlir::triton::intel::TritonGPUToLLVMPipelineManager pipelineManager( mod, context, isAdvancedPathEnabled); mlir::LowerToLLVMOptions option(context); From 8e68a0a290f770067efda0cfd9f0d052f028e4dc Mon Sep 17 00:00:00 2001 From: Dewei Wang Date: Wed, 20 Nov 2024 18:52:50 -0800 Subject: [PATCH 4/5] fix format --- third_party/intel/backend/compiler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py index afb0aafe1a..8dc657cbf9 100644 --- a/third_party/intel/backend/compiler.py +++ b/third_party/intel/backend/compiler.py @@ -234,8 +234,8 @@ def make_ttgir(mod, metadata, opt, properties): pm.enable_debug() if (os.getenv("TRITON_INTEL_ADVANCED_PATH", "0") == "1" or opt.advanced_path): - assert properties["has_subgroup_2d_block_io"] and properties["has_subgroup_matrix_multiply_accumulate"], \ - "Target do not support blocked load/mma" + if not (properties["has_subgroup_2d_block_io"] and properties["has_subgroup_matrix_multiply_accumulate"]): + raise AssertionError("Target do not support blocked load/mma") return XPUBackend.AdvancedPath.make_ttgir(mod, metadata, opt) passes.ttir.add_convert_to_ttgpuir(pm, "xpu", opt.num_warps, opt.threads_per_warp, opt.num_ctas) From 300129762a90ff3786f04a5b3d7325697ea68ed2 Mon Sep 17 00:00:00 2001 From: Dewei Wang Date: Thu, 21 Nov 2024 17:26:12 -0800 Subject: [PATCH 5/5] keep the original way --- third_party/intel/backend/compiler.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py index 8dc657cbf9..a9e6e78c83 100644 --- a/third_party/intel/backend/compiler.py +++ b/third_party/intel/backend/compiler.py @@ -233,9 +233,8 @@ def make_ttgir(mod, metadata, opt, properties): pm = ir.pass_manager(mod.context) pm.enable_debug() - if (os.getenv("TRITON_INTEL_ADVANCED_PATH", "0") == "1" or opt.advanced_path): - if not (properties["has_subgroup_2d_block_io"] and properties["has_subgroup_matrix_multiply_accumulate"]): - raise AssertionError("Target do not support blocked load/mma") + if (properties["has_subgroup_2d_block_io"] and properties["has_subgroup_matrix_multiply_accumulate"] + and (os.getenv("TRITON_INTEL_ADVANCED_PATH", "0") == "1" or opt.advanced_path)): return XPUBackend.AdvancedPath.make_ttgir(mod, metadata, opt) passes.ttir.add_convert_to_ttgpuir(pm, "xpu", opt.num_warps, opt.threads_per_warp, opt.num_ctas)