Skip to content

Commit c427499

Browse files
committed
Add an XPU option to enable the advanced path
1 parent e8b34a0 commit c427499

File tree

6 files changed

+22
-16
lines changed

6 files changed

+22
-16
lines changed

benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -214,6 +214,7 @@ def forward(q, k, v, causal, sm_scale):
214214
num_warps=num_warps, #
215215
num_stages=num_stages, #
216216
grf_mode='large', #
217+
advanced_path=True,
217218
)
218219
return o
219220

third_party/intel/backend/compiler.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -56,6 +56,7 @@ class XPUOptions:
5656
backend_name: str = 'intel'
5757
sanitize_overflow: bool = False
5858
generate_native_code: bool = False
59+
advanced_path: bool = False
5960

6061
def __post_init__(self):
6162
default_libdir = Path(__file__).parent / 'lib'
@@ -233,7 +234,7 @@ def make_ttgir(mod, metadata, opt, properties):
233234
pm.enable_debug()
234235

235236
if (properties["has_subgroup_2d_block_io"] and properties["has_subgroup_matrix_multiply_accumulate"]
236-
and os.getenv("TRITON_INTEL_ADVANCED_PATH", "0") == "1"):
237+
and (os.getenv("TRITON_INTEL_ADVANCED_PATH", "0") == "1" or opt.advanced_path)):
237238
return XPUBackend.AdvancedPath.make_ttgir(mod, metadata, opt)
238239

239240
passes.ttir.add_convert_to_ttgpuir(pm, "xpu", opt.num_warps, opt.threads_per_warp, opt.num_ctas)
@@ -291,7 +292,7 @@ def make_llir(src, metadata, options):
291292
# being used, e.g., convert_layout.
292293
if os.getenv("TRITON_INTEL_REDUCE_TRANSPOSE", "0") != "1":
293294
intel.passes.ttgpuir.add_allocate_shared_memory(pm)
294-
intel.passes.ttgpuir.add_to_llvmir(pm)
295+
intel.passes.ttgpuir.add_to_llvmir(pm, options.advanced_path)
295296
intel.set_fast_math(mod)
296297
passes.convert.add_arith_to_llvmir(pm)
297298
passes.common.add_canonicalizer(pm)

third_party/intel/include/TritonIntelGPUToLLVM/Passes.td

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,11 @@ def ConvertTritonIntelGPUToLLVM
2727
"mlir::triton::TritonDialect",
2828
"mlir::triton::gpu::TritonGPUDialect",
2929
"mlir::triton::TritonGEN::TritonGENDialect"];
30+
let options = [
31+
Option<"advancedPath", "advanced_path",
32+
"bool", /*default*/"false",
33+
"enable advanced path">,
34+
];
3035
}
3136

3237
#endif // TRITONINTELGPU_CONVERSION_PASSES

third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h

Lines changed: 2 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -180,14 +180,8 @@ struct AddSPIRVEnvPattern : public mlir::OpRewritePattern<ModuleOp> {
180180
/// block pointers or not.
181181
class TritonGPUToLLVMPipelineManager {
182182
public:
183-
TritonGPUToLLVMPipelineManager(ModuleOp &mod, MLIRContext *ctx)
184-
: mod(mod), ctx(ctx),
185-
isAdvancedPathEnabled(
186-
mod->hasAttr(gpu::intel::TritonIntelGPUDialect::
187-
getSupportSG2DBlockAttrName()) &&
188-
mod->hasAttr(
189-
gpu::intel::TritonIntelGPUDialect::getSupportDPASAttrName()) &&
190-
mlir::triton::tools::getBoolEnv("TRITON_INTEL_ADVANCED_PATH")) {}
183+
TritonGPUToLLVMPipelineManager(ModuleOp &mod, MLIRContext *ctx, bool advanced)
184+
: mod(mod), ctx(ctx), isAdvancedPathEnabled(advanced) {}
191185

192186
/// FIXME: remove once the block ptr conversion path is capable of handling
193187
/// shared memory.

third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -68,6 +68,10 @@ struct ConvertTritonGPUToLLVM
6868
: public triton::gpu::intel::impl::ConvertTritonIntelGPUToLLVMBase<
6969
ConvertTritonGPUToLLVM> {
7070
using ConvertTritonIntelGPUToLLVMBase::ConvertTritonIntelGPUToLLVMBase;
71+
ConvertTritonGPUToLLVM() = default;
72+
ConvertTritonGPUToLLVM(bool advancedPath) {
73+
this->advancedPath = advancedPath;
74+
}
7175

7276
void getDependentDialects(DialectRegistry &registry) const override {
7377
registry.insert<LLVM::LLVMDialect, TritonGEN::TritonGENDialect,
@@ -78,15 +82,16 @@ struct ConvertTritonGPUToLLVM
7882
MLIRContext *context = &getContext();
7983
ModuleOp mod = getOperation();
8084

81-
mlir::triton::intel::TritonGPUToLLVMPipelineManager pipelineManager(
82-
mod, context);
83-
mlir::LowerToLLVMOptions option(context);
8485
bool isAdvancedPathEnabled =
8586
mod->hasAttr(triton::gpu::intel::TritonIntelGPUDialect::
8687
getSupportSG2DBlockAttrName()) &&
8788
mod->hasAttr(triton::gpu::intel::TritonIntelGPUDialect::
8889
getSupportDPASAttrName()) &&
89-
mlir::triton::tools::getBoolEnv("TRITON_INTEL_ADVANCED_PATH");
90+
(mlir::triton::tools::getBoolEnv("TRITON_INTEL_ADVANCED_PATH") ||
91+
advancedPath);
92+
mlir::triton::intel::TritonGPUToLLVMPipelineManager pipelineManager(
93+
mod, context, isAdvancedPathEnabled);
94+
mlir::LowerToLLVMOptions option(context);
9095
mlir::triton::intel::TargetInfo targetInfo;
9196
TritonIntelGPUToLLVMTypeConverter typeConverter(context, option, targetInfo,
9297
isAdvancedPathEnabled);

third_party/intel/triton_xpu.cc

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -67,8 +67,8 @@ void init_triton_intel_passes_ttir(py::module &&m) {
6767
}
6868

6969
void init_triton_intel_passes_ttgpuir(py::module &&m) {
70-
ADD_PASS_WRAPPER_0("add_to_llvmir",
71-
gpu::intel::createConvertTritonIntelGPUToLLVM);
70+
ADD_PASS_WRAPPER_OPT_1("add_to_llvmir",
71+
gpu::intel::createConvertTritonIntelGPUToLLVM, bool);
7272
ADD_PASS_WRAPPER_0("add_accelerate_matmul",
7373
gpu::intel::createTritonIntelGPUAccelerateMatmul);
7474
ADD_PASS_WRAPPER_0("add_decompose_unsupported_conversions",

0 commit comments

Comments
 (0)