Commit 7375302
Add rewrite_stack_ptr post process pass (#3497)
This PR makes the following changes:

1. Drop the fork-only `TargetInfo::getStackPointer` (so the interface matches upstream) and replace it with an Intel-specific pass, per the discussion in #3046 (comment).
2. Use the public `ControlFlowOpToLLVM` and sync the public `FuncOpToLLVM` lowering into `PipelineManager` so the calling convention matches upstream (the public `FuncOpToLLVM` is difficult to use directly because it contains NVIDIA-specific code); see the sketch below.
3. Add lit tests.
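As a point of reference for item 2, here is a minimal sketch of the lowered calling convention (`@device_fn` and `@kernel` are illustrative names, not part of this PR; the argument shapes follow the CHECK lines of the new lit test further down): a non-inlined device function receives a trailing shared-memory base pointer (`!llvm.ptr<3>`) and a global scratch pointer (`!llvm.ptr<1>`), and a kernel caller that allocates no shared memory forwards a poison value for the shared-memory operand.

```mlir
// Sketch only: hypothetical names; the trailing arguments mirror the CHECK
// lines in the new lit test added by this commit.
module {
  // Device function: the last two parameters are the shared-memory base
  // (address space 3) and the global scratch pointer (address space 1),
  // appended to match the upstream calling convention.
  llvm.func internal @device_fn(%arg0: f32, %arg1: !llvm.ptr<1>,
                                %arg2: !llvm.ptr<3>, %arg3: !llvm.ptr<1>) {
    llvm.return
  }

  // Kernel caller: with no shared memory allocated in the module, the
  // shared-memory operand is materialized as poison (see the
  // tritonintelgpu-rewrite-stack-ptr pass below).
  llvm.func spir_kernelcc @kernel(%arg0: f32, %arg1: !llvm.ptr<1>, %arg2: !llvm.ptr<1>) {
    %smem = llvm.mlir.poison : !llvm.ptr<3>
    llvm.call @device_fn(%arg0, %arg1, %smem, %arg2) : (f32, !llvm.ptr<1>, !llvm.ptr<3>, !llvm.ptr<1>) -> ()
    llvm.return
  }
}
```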
1 parent 5d68d95 commit 7375302

File tree: 21 files changed, +322 −252 lines

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 0 additions & 3 deletions
@@ -98,9 +98,6 @@ class TargetInfoBase {
   virtual void storeOpAnnotation(triton::gpu::LocalStoreOp op,
                                  size_t localStoreOpCount, Type type) const {}
 
-  virtual Value getStackPointer(RewriterBase &rewriter,
-                                FunctionOpInterface funcOp) const = 0;
-
   virtual ~TargetInfoBase() {}
 };
 } // namespace mlir::triton

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 1 addition & 1 deletion
@@ -583,7 +583,7 @@ inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
   auto b = TritonLLVMOpBuilder(loc, rewriter);
   Value offVal = b.i32_val(offset);
   Value base =
-      b.gep(ptrTy, i8_ty, target.getStackPointer(rewriter, func), offVal);
+      b.gep(ptrTy, i8_ty, LLVM::getStackPointer(rewriter, func), offVal);
   return base;
 }
 
lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
@@ -85,7 +85,7 @@ struct CallOpConversion : public ConvertOpToLLVMPattern<triton::CallOp> {
         callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
         adaptor.getOperands(), rewriter);
     if (!caller->hasAttr("allocation.offset")) {
-      auto base = targetInfo.getStackPointer(rewriter, caller);
+      auto base = LLVM::getStackPointer(rewriter, caller);
       promotedOperands.push_back(base);
     } else {
       auto base = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, callOp);

test/Conversion/intel/sub-group-transpose.mlir

Lines changed: 38 additions & 19 deletions
Large diffs are not rendered by default.

test/Conversion/intel/tritongpu_to_gen.mlir

Lines changed: 7 additions & 2 deletions
@@ -603,6 +603,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
   // CHECK-LABEL: basic_alloc_tensor(%arg0: !llvm.ptr<3>)
   tt.func @basic_alloc_tensor() {
     // CHECK-NEXT: llvm.mlir.constant
+    // CHECK-NEXT: llvm.mlir.addressof @global_smem
     // CHECK-NEXT: llvm.getelementptr
     %0 = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #shared0, #smem, mutable>
     tt.return
@@ -1102,7 +1103,9 @@ module attributes {"ttg.target" = "xpu", "ttg.num-ctas" = 1 : i32, "ttg.num-warp
     // CHECK-NEXT: llvm.br ^bb2([[CMPXCHG_RES]] : i32)
     // CHECK-NEXT: ^bb2([[RES:%.*]]: i32):
     // CHECK-NEXT: [[RES_CAST:%.*]] = llvm.bitcast [[RES]] : i32 to f32
-    // CHECK: [[GEP:%.*]] = llvm.getelementptr %arg3[{{.*}}] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+    // CHECK: [[C_0:%.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK: [[SMEM_0:%.*]] = llvm.mlir.addressof @global_smem : !llvm.ptr<3>
+    // CHECK: [[GEP:%.*]] = llvm.getelementptr [[SMEM_0]]{{\[}}[[C_0]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
     // CHECK-NEXT: [[GEP_CAST:%.*]] = llvm.bitcast [[GEP]] : !llvm.ptr<3> to !llvm.ptr<3>
     // CHECK-NEXT: llvm.cond_br [[MASK]], ^bb3, ^bb4
     // CHECK-NEXT: ^bb3:
@@ -1210,7 +1213,9 @@ module attributes {"ttg.target" = "xpu", "ttg.num-ctas" = 1 : i32, "ttg.num-warp
     // CHECK-NEXT: llvm.br ^bb2([[RMW_RES]] : f32)
     // CHECK-NEXT: ^bb2([[RMW_PHI:%.*]]: f32):
    // CHECK-NEXT: [[RMW_CAST:%.*]] = llvm.bitcast [[RMW_PHI]] : f32 to f32
-    // CHECK: [[GEP:%.*]] = llvm.getelementptr %arg3[{{.*}}] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+    // CHECK: [[C_0:%.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK: [[SMEM_0:%.*]] = llvm.mlir.addressof @global_smem : !llvm.ptr<3>
+    // CHECK: [[GEP:%.*]] = llvm.getelementptr [[SMEM_0]]{{\[}}[[C_0]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
     // CHECK-NEXT: [[GEP_CAST:%.*]] = llvm.bitcast [[GEP]] : !llvm.ptr<3> to !llvm.ptr<3>
     // CHECK-NEXT: llvm.cond_br [[PRED]], ^bb3, ^bb4
     // CHECK-NEXT: ^bb3:
Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
+// RUN: triton-opt %s -split-input-file --convert-triton-intel-gpu-to-llvm --tritonintelgpu-rewrite-stack-ptr | FileCheck %s
+
+module attributes {triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_bf16_conversion, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block, triton_intel_gpu.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 0 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
+  // CHECK-LABEL: llvm.func spir_kernelcc @kernel(%arg0: !llvm.ptr<1>, %arg1: !llvm.ptr<1>, %arg2: !llvm.ptr<1>)
+  tt.func public @kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %0 = tt.load %arg0 : !tt.ptr<f32>
+    %1 = tt.load %arg1 : !tt.ptr<f32>
+    // CHECK: llvm.mlir.poison : !llvm.ptr<3>
+    // CHECK: llvm.call @noinline_simple_fn__fp32_fp32_Pfp32__(%8, %17, %arg2, %18, %arg2)
+    tt.call @noinline_simple_fn__fp32_fp32_Pfp32__(%0, %1, %arg2) : (f32, f32, !tt.ptr<f32>) -> ()
+    tt.return
+  }
+  // CHECK: llvm.func internal @noinline_simple_fn__fp32_fp32_Pfp32__(%arg0: f32, %arg1: f32, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<3>, %arg4: !llvm.ptr<1>)
+  tt.func private @noinline_simple_fn__fp32_fp32_Pfp32__(%arg0: f32 {tt.constancy = 1 : i64, tt.contiguity = 1 : i64, tt.divisibility = 1 : i64}, %arg1: f32 {tt.constancy = 1 : i64, tt.contiguity = 1 : i64, tt.divisibility = 1 : i64}, %arg2: !tt.ptr<f32> {tt.constancy = 1 : i64, tt.contiguity = 1 : i64, tt.divisibility = 16 : i64}) attributes {noinline = true} {
+    %0 = arith.addf %arg0, %arg1 fastmath<fast> : f32
+    tt.store %arg2, %0 : !tt.ptr<f32>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
+#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_bf16_conversion, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block, triton_intel_gpu.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 1280 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
+  // CHECK-LABEL: llvm.func spir_kernelcc @kernel(%arg0: !llvm.ptr<1>, %arg1: !llvm.ptr<1>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<3>)
+  tt.func public @kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %0 = tt.load %arg0 : !tt.ptr<f32>
+    %1 = tt.load %arg1 : !tt.ptr<f32>
+    // CHECK: llvm.call @noinline_shared_fn__fp32_fp32_Pfp32__(%8, %17, %arg2, %arg3, %arg2)
+    tt.call @noinline_shared_fn__fp32_fp32_Pfp32__(%0, %1, %arg2) {allocation.offset = 0 : i32} : (f32, f32, !tt.ptr<f32>) -> ()
+    tt.return
+  }
+  // CHECK: llvm.func internal @noinline_shared_fn__fp32_fp32_Pfp32__(%arg0: f32, %arg1: f32, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<3>, %arg4: !llvm.ptr<1>)
+  // CHECK: llvm.getelementptr %arg3[{{.*}}]
+  tt.func private @noinline_shared_fn__fp32_fp32_Pfp32__(%arg0: f32 {tt.constancy = 1 : i64, tt.contiguity = 1 : i64, tt.divisibility = 1 : i64}, %arg1: f32 {tt.constancy = 1 : i64, tt.contiguity = 1 : i64, tt.divisibility = 1 : i64}, %arg2: !tt.ptr<f32> {tt.constancy = 1 : i64, tt.contiguity = 1 : i64, tt.divisibility = 16 : i64}) attributes {noinline = true} {
+    %cst = arith.constant dense<16> : tensor<16x1xi32, #blocked>
+    %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi32, #blocked>
+    %2 = arith.muli %1, %cst : tensor<16x1xi32, #blocked>
+    %3 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+    %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked>
+    %5 = tt.broadcast %2 : tensor<16x1xi32, #blocked> -> tensor<16x16xi32, #blocked>
+    %6 = tt.broadcast %4 : tensor<1x16xi32, #blocked> -> tensor<16x16xi32, #blocked>
+    %7 = arith.addi %5, %6 : tensor<16x16xi32, #blocked>
+    %8 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<16x16x!tt.ptr<f32>, #blocked>
+    %9 = tt.addptr %8, %7 : tensor<16x16x!tt.ptr<f32>, #blocked>, tensor<16x16xi32, #blocked>
+    %10 = tt.load %9 : tensor<16x16x!tt.ptr<f32>, #blocked>
+    %11 = ttg.local_alloc %10 {allocation.offset = 0 : i32} : (tensor<16x16xf32, #blocked>) -> !ttg.memdesc<16x16xf32, #shared, #smem>
+    %12 = tt.splat %arg0 : f32 -> tensor<16x16xf32, #mma>
+    %13 = ttg.local_load %11 : !ttg.memdesc<16x16xf32, #shared, #smem> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
+    %14 = ttg.local_load %11 : !ttg.memdesc<16x16xf32, #shared, #smem> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>>
+    %15 = tt.dot %13, %14, %12, inputPrecision = tf32 : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<16x16xf32, #mma>
+    %16 = tt.splat %arg1 : f32 -> tensor<16x16xf32, #mma>
+    %17 = arith.addf %15, %16 fastmath<fast> : tensor<16x16xf32, #mma>
+    %18 = ttg.convert_layout %17 {allocation.offset = 0 : i32} : tensor<16x16xf32, #mma> -> tensor<16x16xf32, #blocked>
+    tt.store %9, %18 : tensor<16x16x!tt.ptr<f32>, #blocked>
+    tt.return
+  }
+}

third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp

Lines changed: 0 additions & 13 deletions
@@ -447,19 +447,6 @@ void TargetInfo::assertFail(RewriterBase &rewriter, Location loc,
 
 int TargetInfo::getSharedAddressSpace() const { return 3; }
 
-Value TargetInfo::getStackPointer(RewriterBase &rewriter,
-                                  FunctionOpInterface funcOp) const {
-  // See NOTE: [Additional Function Arguments]
-  if (!LLVM::isKernel(funcOp)) {
-    return funcOp.getArgument(funcOp.getNumArguments() - 2);
-  }
-
-  auto mod = funcOp->getParentOfType<ModuleOp>();
-  auto globalBase = dyn_cast<LLVM::GlobalOp>(mod.lookupSymbol("global_smem"));
-  assert(globalBase);
-  return rewriter.create<LLVM::AddressOfOp>(funcOp.getLoc(), globalBase);
-}
-
 int TargetInfo::getAddressSpace(Attribute addressSpace) const {
   int spaceId = 0;
   if (isa<triton::gpu::SharedMemorySpaceAttr>(addressSpace)) {

third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h

Lines changed: 0 additions & 3 deletions
@@ -75,9 +75,6 @@ class TargetInfo : public mlir::triton::TargetInfoBase {
   void storeOpAnnotation(triton::gpu::LocalStoreOp op, size_t localStoreOpCount,
                          Type type) const override;
 
-  Value getStackPointer(RewriterBase &rewriter,
-                        FunctionOpInterface funcOp) const override;
-
 private:
   void printfImpl(Value formatStrStart, int formatStrByteCount, ValueRange args,
                   RewriterBase &rewriter, bool useStdErr) const;

third_party/intel/backend/compiler.py

Lines changed: 2 additions & 1 deletion
@@ -328,8 +328,9 @@ def make_llir(src, metadata, options):
     # solutions for SLM allocation, so this will crash on some operations
     # being used, e.g., convert_layout.
     if os.getenv("TRITON_INTEL_REDUCE_TRANSPOSE", "0") != "1":
-        intel.passes.ttgpuir.add_allocate_shared_memory(pm)
+        passes.ttgpuir.add_allocate_shared_memory(pm)
     intel.passes.ttgpuir.add_to_llvmir(pm, options.advanced_path, options.one_matrix_per_load_for_bt)
+    intel.passes.ttgpuir.add_rewrite_stack_ptr(pm)
     intel.set_fast_math(mod)
     passes.convert.add_arith_to_llvmir(pm)
     passes.common.add_canonicalizer(pm)

third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td

Lines changed: 15 additions & 0 deletions
@@ -362,4 +362,19 @@ tt.func @test(%arg0: tensor<16x32xf32, #mma>) -> tensor<16xf32, #ttg.slice<{dim
                           "mlir::triton::gpu::TritonGPUDialect"];
 }
 
+def TritonIntelGPURewriteStackPtr
+    : Pass<"tritonintelgpu-rewrite-stack-ptr", "mlir::ModuleOp"> {
+  let summary = "rewrite the getStackPointer for Intel by addressofOp replacement";
+
+  let description = [{
+    This pass searches for the global_smem symbol and replaces the addressOfOp with a newly inserted
+    SLM parameter or a PoisonOp to rewrite the getStackPointer for Intel.
+  }];
+
+  let dependentDialects = [
+    "mlir::triton::gpu::TritonGPUDialect",
+    "mlir::triton::gpu::intel::TritonIntelGPUDialect", "mlir::scf::SCFDialect",
+    "mlir::arith::ArithDialect"
+  ];
+}
 #endif // TRITON_INTEL_GPU_PASSES
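Concretely, the pass replaces the `llvm.mlir.addressof @global_smem` emitted by the shared lowering either with an SLM pointer appended to the kernel's parameter list (when the module allocates shared memory, `ttg.shared != 0`) or with an `llvm.mlir.poison` value (when it does not). A minimal before/after sketch, inferred from the lit test above (kernel names are illustrative):

```mlir
module {
  llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>

  // Before --tritonintelgpu-rewrite-stack-ptr: the shared lowering reads the
  // stack base from the @global_smem module global.
  llvm.func spir_kernelcc @kernel_before(%arg0: !llvm.ptr<1>) {
    %base = llvm.mlir.addressof @global_smem : !llvm.ptr<3>
    %c0 = llvm.mlir.constant(0 : i32) : i32
    %ptr = llvm.getelementptr %base[%c0] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
    llvm.return
  }

  // After: with ttg.shared != 0, the kernel gains a trailing !llvm.ptr<3>
  // SLM parameter and the addressof is replaced by that argument; with
  // ttg.shared = 0 it would be replaced by llvm.mlir.poison instead.
  llvm.func spir_kernelcc @kernel_after(%arg0: !llvm.ptr<1>, %arg1: !llvm.ptr<3>) {
    %c0 = llvm.mlir.constant(0 : i32) : i32
    %ptr = llvm.getelementptr %arg1[%c0] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
    llvm.return
  }
}
```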
