Skip to content

Commit fff08ef

Browse files
authored
Use module attribute to specify target arch. (#3387)
This variant makes the target arch attribute mandatory for conversion to the LLVM dialect. We can fall back to the optional attribute if it looks too intrusive. Signed-off-by: Ilya Enkovich <[email protected]>
1 parent 8b1aee9 commit fff08ef

File tree

9 files changed

+41
-12
lines changed

9 files changed

+41
-12
lines changed

test/TritonIntelGPU/triton_annotate_module.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
module {
44
// COM: Ensure that the 'threads-per-warp' attribute is set according to the option.
5-
// CHECK: module attributes {triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block, "ttg.threads-per-warp" = 32 : i32}
5+
// CHECK: module attributes {triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block, triton_intel_gpu.target_arch = "spir64", "ttg.threads-per-warp" = 32 : i32}
66
tt.func @kernel() {
77
tt.return
88
}
@@ -13,7 +13,7 @@ module {
1313
module {
1414
// COM: Ensure that the 'threads-per-warp' attribute is overwritten when the kernel contains a 'tt.dot'
1515
// operation that can be lowered to DPAS instructions.
16-
// CHECK: module attributes {triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block, "ttg.threads-per-warp" = 16 : i32}
16+
// CHECK: module attributes {triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block, triton_intel_gpu.target_arch = "spir64", "ttg.threads-per-warp" = 16 : i32}
1717
tt.func @kernel() {
1818
%a = arith.constant dense<1.00e+00> : tensor<128x32xf16>
1919
%b = arith.constant dense<2.00e+00> : tensor<32x128xf16>

third_party/intel/backend/compiler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,10 +252,12 @@ def make_ttgir(mod, metadata, opt, properties):
252252
# Annotate module with information required by subsequent transformations.
253253
pm = ir.pass_manager(mod.context)
254254
pm.enable_debug()
255+
target_arch = "spir64"
255256
intel.passes.ttgpuir.add_triton_annotate_module(pm, min(properties["sub_group_sizes"]),
256257
properties["has_subgroup_2d_block_io"],
257258
properties["has_subgroup_matrix_multiply_accumulate"],
258-
properties["has_bfloat16_conversions"], opt.threads_per_warp)
259+
properties["has_bfloat16_conversions"], opt.threads_per_warp,
260+
target_arch)
259261
pm.run(mod)
260262

261263
# Overwrite the threads_per_warp option with the module annotation.

third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUDialect.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,13 @@ def TritonIntelGPU_Dialect : Dialect {
4848
static constexpr llvm::StringRef getBlockIOAttrName() {
4949
return "triton_intel_gpu.block_io";
5050
}
51+
52+
/// Get the name of the attribute used to specify the target architecture. This
53+
/// attribute matches the architecture in a target triple used for the resulting LLVM
54+
/// IR module.
55+
static constexpr llvm::StringRef getTargetArchAttrName() {
56+
return "triton_intel_gpu.target_arch";
57+
}
5158
}];
5259

5360
let useDefaultAttributePrinterParser = 1;

third_party/intel/include/Dialect/TritonIntelGPU/IR/Utils.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define TRITON_DIALECT_TRITON_INTEL_GPU_IR_UTILS_H
1111

1212
#include "intel/include/Analysis/AxisInfo.h"
13+
#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
1314
#include "mlir/IR/Operation.h"
1415
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
1516
#include <triton/Tools/Sys/GetEnv.hpp>
@@ -46,6 +47,16 @@ inline unsigned getNumElementsPerThread(
4647
inline bool applyTransposedReduction() {
4748
return tools::getBoolEnv("TRITON_INTEL_REDUCE_TRANSPOSE");
4849
}
50+
51+
// Check if the module's target arch is SPIR-V. If there is no target arch
52+
// attribute, then we assume a SPIR-V target by default.
53+
inline bool hasSpirvTargetArch(Operation *op) {
54+
if (!isa<ModuleOp>(op))
55+
op = op->getParentOfType<ModuleOp>();
56+
auto arch = op->getAttrOfType<StringAttr>(
57+
triton::gpu::intel::TritonIntelGPUDialect::getTargetArchAttrName());
58+
return !arch || arch.str().substr(0, 4) == "spir";
59+
}
4960
} // namespace mlir::triton::gpu::intel
5061

5162
#endif // TRITON_DIALECT_TRITON_INTEL_GPU_IR_UTILS_H

third_party/intel/include/TritonAnnotateModule/Passes.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ def TritonAnnotateModule: Pass<"triton-annotate-module", "mlir::ModuleOp"> {
3636
Option<"threadsPerWarp", "threads-per-warp",
3737
"unsigned", /*default*/"32",
3838
"number of threads per warp (aka subgroup size)">,
39+
Option<"targetArch", "target-arch", "std::string", /*default*/"\"spir64\"",
40+
"target architecture name">
3941
];
4042
}
4143

third_party/intel/lib/TritonAnnotateModule/TritonAnnotateModule.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ struct TritonAnnotateModule
3737
intel::TritonIntelGPUDialect::getSupportBF16ConversionAttrName(),
3838
builder.getUnitAttr());
3939

40+
mod->setAttr(intel::TritonIntelGPUDialect::getTargetArchAttrName(),
41+
builder.getStringAttr(targetArch));
42+
4043
DPASAnalysis &dpasAnalysis = getAnalysis<DPASAnalysis>();
4144
setThreadsPerWarp(mod, dpasAnalysis);
4245
}

third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
2424
#include "mlir/IR/PatternMatch.h"
2525

26+
#include "intel/include/Dialect/TritonIntelGPU/IR/Utils.h"
2627
#include "intel/include/GPUToTritonGEN/GPUToTritonGENPass.h"
2728
#include "intel/include/TritonGENToLLVM/TritonGENToLLVMPass.h"
2829
#include "triton/Analysis/AxisInfo.h"
@@ -143,7 +144,7 @@ struct AddSPIRVEnvPattern : public mlir::OpRewritePattern<ModuleOp> {
143144

144145
LogicalResult matchAndRewrite(ModuleOp op,
145146
PatternRewriter &rewriter) const override {
146-
if (spirv::lookupTargetEnv(op)) {
147+
if (!gpu::intel::hasSpirvTargetArch(op) || spirv::lookupTargetEnv(op)) {
147148
return failure();
148149
}
149150

third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,10 @@ class TritonLLVMConversionTarget : public ConversionTarget {
5959
addIllegalDialect<triton::gpu::intel::TritonIntelGPUDialect>();
6060
addIllegalDialect<mlir::gpu::GPUDialect>();
6161
addLegalOp<mlir::UnrealizedConversionCastOp>();
62-
addDynamicallyLegalOp<ModuleOp>(
63-
[](ModuleOp op) { return spirv::lookupTargetEnv(op) != nullptr; });
62+
addDynamicallyLegalOp<ModuleOp>([](ModuleOp op) {
63+
return !triton::gpu::intel::hasSpirvTargetArch(op) ||
64+
spirv::lookupTargetEnv(op) != nullptr;
65+
});
6466
}
6567
};
6668

third_party/intel/triton_xpu.cc

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,11 @@ using ret = py::return_value_policy;
4747
m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1) { \
4848
pm.addPass(builder({val0, val1})); \
4949
})
50-
#define ADD_PASS_WRAPPER_OPT_5(name, builder, ty0, ty1, ty2, ty3, ty4) \
51-
m.def(name, \
52-
[](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2, ty3 val3, \
53-
ty4 val4) { pm.addPass(builder({val0, val1, val2, val3, val4})); })
50+
#define ADD_PASS_WRAPPER_OPT_6(name, builder, ty0, ty1, ty2, ty3, ty4, ty5) \
51+
m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2, \
52+
ty3 val3, ty4 val4, ty5 val5) { \
53+
pm.addPass(builder({val0, val1, val2, val3, val4, val5})); \
54+
})
5455

5556
static uint32_t findKernels(llvm::Module &M,
5657
std::set<llvm::Function *> &functions) {
@@ -97,9 +98,9 @@ void init_triton_intel_passes_ttgpuir(py::module &&m) {
9798
gpu::intel::createTritonIntelGPUMatchTargetSize);
9899
ADD_PASS_WRAPPER_0("add_schedule_load",
99100
gpu::intel::createTritonIntelGPUScheduleLoad);
100-
ADD_PASS_WRAPPER_OPT_5("add_triton_annotate_module",
101+
ADD_PASS_WRAPPER_OPT_6("add_triton_annotate_module",
101102
gpu::intel::createTritonAnnotateModule, unsigned, bool,
102-
bool, bool, unsigned);
103+
bool, bool, unsigned, const std::string &);
103104
ADD_PASS_WRAPPER_0("add_reduce_data_duplication",
104105
gpu::intel::createTritonIntelGPUReduceDataDuplication);
105106
ADD_PASS_WRAPPER_0("add_materialize_block_pointer",

0 commit comments

Comments
 (0)