Skip to content

Commit fc4c5d2

Browse files
committed
Add XPU option to enable advanced path
1 parent e8b34a0 commit fc4c5d2

File tree

6 files changed

+22
-14
lines changed

6 files changed

+22
-14
lines changed

benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ def forward(q, k, v, causal, sm_scale):
214214
num_warps=num_warps, #
215215
num_stages=num_stages, #
216216
grf_mode='large', #
217+
advanced_path=True,
217218
)
218219
return o
219220

third_party/intel/backend/compiler.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def _path_to_binary(binary: str):
3636

3737
@dataclass
3838
class XPUOptions:
39+
advanced_path: bool = False
3940
num_warps: int = 4
4041
num_ctas: int = 1
4142
num_stages: int = 2
@@ -232,6 +233,8 @@ def make_ttgir(mod, metadata, opt, properties):
232233
pm = ir.pass_manager(mod.context)
233234
pm.enable_debug()
234235

236+
if (opt.advanced_path):
237+
return XPUBackend.AdvancedPath.make_ttgir(mod, metadata, opt)
235238
if (properties["has_subgroup_2d_block_io"] and properties["has_subgroup_matrix_multiply_accumulate"]
236239
and os.getenv("TRITON_INTEL_ADVANCED_PATH", "0") == "1"):
237240
return XPUBackend.AdvancedPath.make_ttgir(mod, metadata, opt)
@@ -291,7 +294,7 @@ def make_llir(src, metadata, options):
291294
# being used, e.g., convert_layout.
292295
if os.getenv("TRITON_INTEL_REDUCE_TRANSPOSE", "0") != "1":
293296
intel.passes.ttgpuir.add_allocate_shared_memory(pm)
294-
intel.passes.ttgpuir.add_to_llvmir(pm)
297+
intel.passes.ttgpuir.add_to_llvmir(pm, options.advanced_path)
295298
intel.set_fast_math(mod)
296299
passes.convert.add_arith_to_llvmir(pm)
297300
passes.common.add_canonicalizer(pm)

third_party/intel/include/TritonIntelGPUToLLVM/Passes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@ def ConvertTritonIntelGPUToLLVM
2727
"mlir::triton::TritonDialect",
2828
"mlir::triton::gpu::TritonGPUDialect",
2929
"mlir::triton::TritonGEN::TritonGENDialect"];
30+
let options = [
31+
Option<"advancedPath", "advanced_path",
32+
"bool", /*default*/"false",
33+
"enable advanced path">,
34+
];
3035
}
3136

3237
#endif // TRITONINTELGPU_CONVERSION_PASSES

third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -180,14 +180,8 @@ struct AddSPIRVEnvPattern : public mlir::OpRewritePattern<ModuleOp> {
180180
/// block pointers or not.
181181
class TritonGPUToLLVMPipelineManager {
182182
public:
183-
TritonGPUToLLVMPipelineManager(ModuleOp &mod, MLIRContext *ctx)
184-
: mod(mod), ctx(ctx),
185-
isAdvancedPathEnabled(
186-
mod->hasAttr(gpu::intel::TritonIntelGPUDialect::
187-
getSupportSG2DBlockAttrName()) &&
188-
mod->hasAttr(
189-
gpu::intel::TritonIntelGPUDialect::getSupportDPASAttrName()) &&
190-
mlir::triton::tools::getBoolEnv("TRITON_INTEL_ADVANCED_PATH")) {}
183+
TritonGPUToLLVMPipelineManager(ModuleOp &mod, MLIRContext *ctx, bool advanced)
184+
: mod(mod), ctx(ctx), isAdvancedPathEnabled(advanced) {}
191185

192186
/// FIXME: remove once the block ptr conversion path is capable of handling
193187
/// shared memory.

third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ struct ConvertTritonGPUToLLVM
6868
: public triton::gpu::intel::impl::ConvertTritonIntelGPUToLLVMBase<
6969
ConvertTritonGPUToLLVM> {
7070
using ConvertTritonIntelGPUToLLVMBase::ConvertTritonIntelGPUToLLVMBase;
71+
ConvertTritonGPUToLLVM() = default;
72+
ConvertTritonGPUToLLVM(bool advancedPath) {
73+
this->advancedPath = advancedPath;
74+
}
7175

7276
void getDependentDialects(DialectRegistry &registry) const override {
7377
registry.insert<LLVM::LLVMDialect, TritonGEN::TritonGENDialect,
@@ -78,15 +82,16 @@ struct ConvertTritonGPUToLLVM
7882
MLIRContext *context = &getContext();
7983
ModuleOp mod = getOperation();
8084

81-
mlir::triton::intel::TritonGPUToLLVMPipelineManager pipelineManager(
82-
mod, context);
83-
mlir::LowerToLLVMOptions option(context);
8485
bool isAdvancedPathEnabled =
8586
mod->hasAttr(triton::gpu::intel::TritonIntelGPUDialect::
8687
getSupportSG2DBlockAttrName()) &&
8788
mod->hasAttr(triton::gpu::intel::TritonIntelGPUDialect::
8889
getSupportDPASAttrName()) &&
8990
mlir::triton::tools::getBoolEnv("TRITON_INTEL_ADVANCED_PATH");
91+
isAdvancedPathEnabled |= advancedPath;
92+
mlir::triton::intel::TritonGPUToLLVMPipelineManager pipelineManager(
93+
mod, context, isAdvancedPathEnabled);
94+
mlir::LowerToLLVMOptions option(context);
9095
mlir::triton::intel::TargetInfo targetInfo;
9196
TritonIntelGPUToLLVMTypeConverter typeConverter(context, option, targetInfo,
9297
isAdvancedPathEnabled);

third_party/intel/triton_xpu.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@ void init_triton_intel_passes_ttir(py::module &&m) {
6767
}
6868

6969
void init_triton_intel_passes_ttgpuir(py::module &&m) {
70-
ADD_PASS_WRAPPER_0("add_to_llvmir",
71-
gpu::intel::createConvertTritonIntelGPUToLLVM);
70+
ADD_PASS_WRAPPER_OPT_1("add_to_llvmir",
71+
gpu::intel::createConvertTritonIntelGPUToLLVM, bool);
7272
ADD_PASS_WRAPPER_0("add_accelerate_matmul",
7373
gpu::intel::createTritonIntelGPUAccelerateMatmul);
7474
ADD_PASS_WRAPPER_0("add_decompose_unsupported_conversions",

0 commit comments

Comments
 (0)