Skip to content

Commit 4e98a9c

Browse files
add xpu option to enable advanced path (#2732)
1 parent 816d7ef commit 4e98a9c

File tree

6 files changed

+27
-19
lines changed

6 files changed

+27
-19
lines changed

benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ def forward(q, k, v, causal, sm_scale):
214214
num_warps=num_warps, #
215215
num_stages=num_stages, #
216216
grf_mode='large', #
217+
advanced_path=True, #
217218
)
218219
return o
219220

third_party/intel/backend/compiler.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class XPUOptions:
5656
backend_name: str = 'intel'
5757
sanitize_overflow: bool = False
5858
generate_native_code: bool = False
59+
advanced_path: bool = False
5960

6061
def __post_init__(self):
6162
default_libdir = Path(__file__).parent / 'lib'
@@ -233,7 +234,7 @@ def make_ttgir(mod, metadata, opt, properties):
233234
pm.enable_debug()
234235

235236
if (properties["has_subgroup_2d_block_io"] and properties["has_subgroup_matrix_multiply_accumulate"]
236-
and os.getenv("TRITON_INTEL_ADVANCED_PATH", "0") == "1"):
237+
and (os.getenv("TRITON_INTEL_ADVANCED_PATH", "0") == "1" or opt.advanced_path)):
237238
return XPUBackend.AdvancedPath.make_ttgir(mod, metadata, opt)
238239

239240
passes.ttir.add_convert_to_ttgpuir(pm, "xpu", opt.num_warps, opt.threads_per_warp, opt.num_ctas)
@@ -292,7 +293,7 @@ def make_llir(src, metadata, options):
292293
# being used, e.g., convert_layout.
293294
if os.getenv("TRITON_INTEL_REDUCE_TRANSPOSE", "0") != "1":
294295
intel.passes.ttgpuir.add_allocate_shared_memory(pm)
295-
intel.passes.ttgpuir.add_to_llvmir(pm)
296+
intel.passes.ttgpuir.add_to_llvmir(pm, options.advanced_path)
296297
intel.set_fast_math(mod)
297298
passes.convert.add_arith_to_llvmir(pm)
298299
passes.common.add_canonicalizer(pm)

third_party/intel/include/TritonIntelGPUToLLVM/Passes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@ def ConvertTritonIntelGPUToLLVM
2727
"mlir::triton::TritonDialect",
2828
"mlir::triton::gpu::TritonGPUDialect",
2929
"mlir::triton::TritonGEN::TritonGENDialect"];
30+
let options = [
31+
Option<"advancedPath", "advanced_path",
32+
"bool", /*default*/"false",
33+
"enable advanced path">,
34+
];
3035
}
3136

3237
#endif // TRITONINTELGPU_CONVERSION_PASSES

third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -181,14 +181,8 @@ struct AddSPIRVEnvPattern : public mlir::OpRewritePattern<ModuleOp> {
181181
/// block pointers or not.
182182
class TritonGPUToLLVMPipelineManager {
183183
public:
184-
TritonGPUToLLVMPipelineManager(ModuleOp &mod, MLIRContext *ctx)
185-
: mod(mod), ctx(ctx),
186-
isAdvancedPathEnabled(
187-
mod->hasAttr(gpu::intel::TritonIntelGPUDialect::
188-
getSupportSG2DBlockAttrName()) &&
189-
mod->hasAttr(
190-
gpu::intel::TritonIntelGPUDialect::getSupportDPASAttrName()) &&
191-
mlir::triton::tools::getBoolEnv("TRITON_INTEL_ADVANCED_PATH")) {}
184+
TritonGPUToLLVMPipelineManager(ModuleOp &mod, MLIRContext *ctx, bool advanced)
185+
: mod(mod), ctx(ctx), isAdvancedPathEnabled(advanced) {}
192186

193187
/// FIXME: remove once the block ptr conversion path is capable of handling
194188
/// shared memory.

third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ struct ConvertTritonGPUToLLVM
6868
: public triton::gpu::intel::impl::ConvertTritonIntelGPUToLLVMBase<
6969
ConvertTritonGPUToLLVM> {
7070
using ConvertTritonIntelGPUToLLVMBase::ConvertTritonIntelGPUToLLVMBase;
71+
ConvertTritonGPUToLLVM() = default;
72+
ConvertTritonGPUToLLVM(bool advancedPath) {
73+
this->advancedPath = advancedPath;
74+
}
7175

7276
void getDependentDialects(DialectRegistry &registry) const override {
7377
registry.insert<LLVM::LLVMDialect, TritonGEN::TritonGENDialect,
@@ -78,15 +82,18 @@ struct ConvertTritonGPUToLLVM
7882
MLIRContext *context = &getContext();
7983
ModuleOp mod = getOperation();
8084

85+
bool isAdvancedPathEnabled =
86+
mlir::triton::tools::getBoolEnv("TRITON_INTEL_ADVANCED_PATH") ||
87+
advancedPath;
88+
if (isAdvancedPathEnabled)
89+
assert(mod->hasAttr(triton::gpu::intel::TritonIntelGPUDialect::
90+
getSupportSG2DBlockAttrName()) &&
91+
mod->hasAttr(triton::gpu::intel::TritonIntelGPUDialect::
92+
getSupportDPASAttrName()) &&
93+
"Target do not support blocked load/mma");
8194
mlir::triton::intel::TritonGPUToLLVMPipelineManager pipelineManager(
82-
mod, context);
95+
mod, context, isAdvancedPathEnabled);
8396
mlir::LowerToLLVMOptions option(context);
84-
bool isAdvancedPathEnabled =
85-
mod->hasAttr(triton::gpu::intel::TritonIntelGPUDialect::
86-
getSupportSG2DBlockAttrName()) &&
87-
mod->hasAttr(triton::gpu::intel::TritonIntelGPUDialect::
88-
getSupportDPASAttrName()) &&
89-
mlir::triton::tools::getBoolEnv("TRITON_INTEL_ADVANCED_PATH");
9097
mlir::triton::intel::TargetInfo targetInfo;
9198
TritonIntelGPUToLLVMTypeConverter typeConverter(context, option, targetInfo,
9299
isAdvancedPathEnabled);

third_party/intel/triton_xpu.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@ void init_triton_intel_passes_ttir(py::module &&m) {
6767
}
6868

6969
void init_triton_intel_passes_ttgpuir(py::module &&m) {
70-
ADD_PASS_WRAPPER_0("add_to_llvmir",
71-
gpu::intel::createConvertTritonIntelGPUToLLVM);
70+
ADD_PASS_WRAPPER_OPT_1("add_to_llvmir",
71+
gpu::intel::createConvertTritonIntelGPUToLLVM, bool);
7272
ADD_PASS_WRAPPER_0("add_accelerate_matmul",
7373
gpu::intel::createTritonIntelGPUAccelerateMatmul);
7474
ADD_PASS_WRAPPER_0("add_decompose_unsupported_conversions",

0 commit comments

Comments
 (0)