Skip to content

Commit ba3ec66

Browse files
authored
[WS] reorder partition-loops and lower-aref (#7927)
* Reorder `partition-loops` and `lower-aref` passes * Split `lower-aref` to `lower-aref` and `assign-stage-phase` passes * the split separates concerns, as `assign-stage-phase` is a more complex pass and testing/debugging can be focused on correctness of stage/phase assignment w/o added complexity of `aref->mbarrier` lowering. * `assign-stage-phase` uses enterOp token to assign same stage variable that enterOp uses to exitOp, instead of previously having separate stage for enterOps/exitOps. * `lower-aref` testing verifies correctness of `aref->mbarrier` lowerings * in `load-mma-specialization` don't place final `waitOp` inside ws-region, revert to original behavior before triton-lang/triton#7757, as that change causes perf regression with this PR. Keeping ws.tag to differentiate partitions in different loops as it will be relied upon in `aref-tmem-insertion` (WIP). This is prep PR needed for `aref-tmem-insertion` WIP (will be submitted after this one) # New contributor declaration - [x] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [x] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [ ] This PR does not need a test because `FILL THIS IN`. - Select one of the following. - [ ] I have not added any `lit` tests. - [x] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
1 parent 6576e50 commit ba3ec66

File tree

12 files changed

+829
-812
lines changed

12 files changed

+829
-812
lines changed

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/AutomaticWarpSpecialization.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,9 @@ void AutomaticWarpSpecialization::runOnOperation() {
4242
// pm.addPass(arith::createIntRangeOptimizationsPass());
4343
pm.addPass(createSCCPPass());
4444
pm.addPass(createCSEPass());
45-
pm.addPass(createTritonGPUPartitionLoops());
45+
pm.addPass(createNVWSAssignStagePhase());
4646
pm.addPass(createNVWSLowerAref());
47+
pm.addPass(createTritonGPUPartitionLoops());
4748
pm.addPass(createNVWSLowerWarpGroup());
4849
if (failed(runPipeline(pm, getOperation())))
4950
return signalPassFailure();

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/LoadMMASpecialization.cpp

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -811,16 +811,7 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
811811
Value lastIndex = loop.getResult(index.getArgNumber() - 1);
812812
Value lastPhase = loop.getResult(phase.getArgNumber() - 1);
813813
Value lastBar = createSingleBufferView(b, nodes.back().barNext, lastIndex);
814-
auto waitBarrierOp = b.create<ttng::WaitBarrierOp>(lastBar, lastPhase);
815-
auto node_front = nodes.front();
816-
auto partition = schedule.getPartition(inBody(node_front.op));
817-
PartitionBuilder b(waitBarrierOp->getLoc(), waitBarrierOp);
818-
lastBar.getDefiningOp()->setAttr(kWarpSpecializeTagAttrName,
819-
b.getI32IntegerAttr(schedule.getTag()));
820-
waitBarrierOp->setAttr(kWarpSpecializeTagAttrName,
821-
b.getI32IntegerAttr(schedule.getTag()));
822-
b.assignPartition(lastBar.getDefiningOp(), *partition);
823-
b.assignPartition(waitBarrierOp, *partition);
814+
b.create<ttng::WaitBarrierOp>(lastBar, lastPhase);
824815
}
825816

826817
llvm::SetVector<Operation *> predOps;

test/NVWS/assign_stage_phase.mlir

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
// RUN: triton-opt %s -split-input-file --allow-unregistered-dialect --nvws-assign-stage-phase -canonicalize | FileCheck %s
2+
3+
#shared0 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 16}>
4+
#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
5+
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
6+
#smem = #ttg.shared_memory
7+
module attributes {"ttg.num-warps" = 4 : i32} {
8+
9+
//CHECK-LABEL: @two_consumers
10+
tt.func @two_consumers(%arg0: i32, %arg1: i32, %arg2: i32) {
11+
// CHECK: [[C3:%.*]] = arith.constant 3 : i32
12+
// CHECK: [[C1:%.*]] = arith.constant 1 : i32
13+
// CHECK: [[C2:%.*]] = arith.constant 2 : i32
14+
// CHECK: [[C0:%.*]] = arith.constant 0 : i32
15+
%ub = arith.constant 4 : i32
16+
%c0_i32 = arith.constant 0 : i32
17+
%0 = ttg.local_alloc : () -> !ttg.memdesc<3x1xi32, #shared, #smem, mutable>
18+
// CHECK: [[AREF:%.*]] = nvws.aref.create
19+
%1 = nvws.aref.create %0 : <[!ttg.memdesc<3x1xi32, #shared, #smem, mutable>]>
20+
// CHECK: [[IDX:%.*]]:6 = scf.for [[I:%.*]] = [[LB:%.*]] to [[UB:%.*]] step [[STEP:%.*]] iter_args([[S0:%.*]] = [[C2]], [[P0:%.*]] = [[C0]], [[S1:%.*]] = [[C2]], [[P1:%.*]] = [[C1]], [[S2:%.*]] = [[C2]], [[P2:%.*]] = [[C1]])
21+
scf.for %arg3 = %arg0 to %arg1 step %arg2 : i32 {
22+
%2 = "op_a"() {ttg.partition = 0 : i32} : () -> tensor<1xi32, #blocked>
23+
// CHECK: op_a
24+
// CHECK-NEXT: [[S0a:%.*]] = arith.addi [[S0]], [[C1]]
25+
// CHECK-NEXT: [[CMP:%.*]] = arith.cmpi eq, [[S0a]], [[C3]]
26+
// CHECK-NEXT: [[S0b:%.*]] = arith.select [[CMP]], [[C0]], [[S0a]]
27+
// CHECK-NEXT: [[P0a:%.*]] = arith.xori [[P0]], [[C1]]
28+
// CHECK-NEXT: [[P0b:%.*]] = arith.select [[CMP]], [[P0a]], [[P0]]
29+
// CHECK-NEXT: put.enter [[AREF]][[[S0b]], [[P0b]]]
30+
%buffers, %token = nvws.aref.put.enter %1[%c0_i32, %c0_i32] {ttg.partition = 0 : i32} : <[!ttg.memdesc<3x1xi32, #shared, #smem, mutable>]> -> !ttg.memdesc<1xi32, #shared, #smem, mutable, 1x1>, !ttg.async.token
31+
ttg.local_store %2, %buffers {ttg.partition = 0 : i32} : tensor<1xi32, #blocked> -> !ttg.memdesc<1xi32, #shared, #smem, mutable, 1x1>
32+
// CHECK: put.exit [[AREF]][[[S0b]]]
33+
nvws.aref.put.exit %1[%c0_i32], %token [#nvws.async_op<none>] {ttg.partition = 0 : i32} : <[!ttg.memdesc<3x1xi32, #shared, #smem, mutable>]>, !ttg.async.token
34+
35+
// CHECK-NEXT: [[S1a:%.*]] = arith.addi [[S1]], [[C1]]
36+
// CHECK-NEXT: [[CMP:%.*]] = arith.cmpi eq, [[S1a]], [[C3]]
37+
// CHECK-NEXT: [[S1b:%.*]] = arith.select [[CMP]], [[C0]], [[S1a]]
38+
// CHECK-NEXT: [[P1a:%.*]] = arith.xori [[P1]], [[C1]]
39+
// CHECK-NEXT: [[P1b:%.*]] = arith.select [[CMP]], [[P1a]], [[P1]]
40+
// CHECK-NEXT: {{.*}}, [[TOK1:%.*]] = nvws.aref.get.enter [[AREF]][[[S1b]], [[P1b]]] {ttg.partition = 1 : i32}
41+
%buffers_0, %token_1 = nvws.aref.get.enter %1[%c0_i32, %c0_i32] {ttg.partition = 1 : i32} : <[!ttg.memdesc<3x1xi32, #shared, #smem, mutable>]> -> !ttg.memdesc<1xi32, #shared, #smem, mutable, 1x1>, !ttg.async.token
42+
%3 = ttg.local_load %buffers_0 {ttg.partition = 1 : i32} : !ttg.memdesc<1xi32, #shared, #smem, mutable, 1x1> -> tensor<1xi32, #blocked>
43+
// CHECK: get.exit [[AREF]][[[S1b]]], [[TOK1]] [#nvws.async_op<none>] {ttg.partition = 1 : i32}
44+
nvws.aref.get.exit %1[%c0_i32], %token_1 [#nvws.async_op<none>] {ttg.partition = 1 : i32} : <[!ttg.memdesc<3x1xi32, #shared, #smem, mutable>]>, !ttg.async.token
45+
"op_b"(%3) {ttg.partition = 1 : i32} : (tensor<1xi32, #blocked>) -> ()
46+
47+
// CHECK: op_b
48+
// CHECK-NEXT: [[S2a:%.*]] = arith.addi [[S2]], [[C1]]
49+
// CHECK-NEXT: [[CMP:%.*]] = arith.cmpi eq, [[S2a]], [[C3]]
50+
// CHECK-NEXT: [[S2b:%.*]] = arith.select [[CMP]], [[C0]], [[S2a]]
51+
// CHECK-NEXT: [[P2a:%.*]] = arith.xori [[P2]], [[C1]]
52+
// CHECK-NEXT: [[P2b:%.*]] = arith.select [[CMP]], [[P2a]], [[P2]]
53+
// CHECK-NEXT: {{.*}}, [[TOK2:%.*]] = nvws.aref.get.enter [[AREF]][[[S2b]], [[P2b]]] {ttg.partition = 2 : i32}
54+
%buffers_2, %token_3 = nvws.aref.get.enter %1[%c0_i32, %c0_i32] {ttg.partition = 2 : i32} : <[!ttg.memdesc<3x1xi32, #shared, #smem, mutable>]> -> !ttg.memdesc<1xi32, #shared, #smem, mutable, 1x1>, !ttg.async.token
55+
%4 = ttg.local_load %buffers_2 {ttg.partition = 2 : i32} : !ttg.memdesc<1xi32, #shared, #smem, mutable, 1x1> -> tensor<1xi32, #blocked>
56+
// CHECK: get.exit [[AREF]][[[S2b]]], [[TOK2]] [#nvws.async_op<none>] {ttg.partition = 2 : i32}
57+
nvws.aref.get.exit %1[%c0_i32], %token_3 [#nvws.async_op<none>] {ttg.partition = 2 : i32} : <[!ttg.memdesc<3x1xi32, #shared, #smem, mutable>]>, !ttg.async.token
58+
"op_c"(%4) {ttg.partition = 2 : i32} : (tensor<1xi32, #blocked>) -> ()
59+
"op_d"(%4) {ttg.partition = 2 : i32} : (tensor<1xi32, #blocked>) -> ()
60+
// CHECK: op_c
61+
// CHECK-NEXT: op_d
62+
// CHECK-NEXT: yield [[S0b]], [[P0b]], [[S1b]], [[P1b]], [[S2b]], [[P2b]]
63+
64+
} {ttg.partition.stages = [0 : i32, 2 : i32, 2 : i32], ttg.warp_specialize.tag = 0 : i32}
65+
66+
ttg.local_dealloc %0 : !ttg.memdesc<3x1xi32, #shared, #smem, mutable>
67+
tt.return
68+
}
69+
70+
}
71+
72+
// -----
73+
74+
#shared0 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 16}>
75+
#smem = #ttg.shared_memory
76+
module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} {
77+
//CHECK-LABEL: @aref_lowering
78+
tt.func @aref_lowering(%d : !ttg.memdesc<3x64x16xf16, #shared0, #smem>,
79+
%e : !ttg.memdesc<3x16x32xf16, #shared0, #smem>,
80+
%cond : i1) {
81+
// CHECK: [[C3:%.*]] = arith.constant 3 : i32
82+
// CHECK: [[C2:%.*]] = arith.constant 2 : i32
83+
// CHECK: [[C0:%.*]] = arith.constant 0 : i32
84+
// CHECK: [[C1:%.*]] = arith.constant 1 : i32
85+
%c0_i32 = arith.constant 0 : i32
86+
%c1_i32 = arith.constant 1 : i32
87+
%lb = arith.constant 0 : i32
88+
%ub = arith.constant 4 : i32
89+
90+
// CHECK: [[AREF0:%.*]] = nvws.aref.create
91+
// CHECK-NEXT: [[AREF1:%.*]] = nvws.aref.create
92+
%aref0 = nvws.aref.create %d, %e : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>
93+
%aref1 = nvws.aref.create %d, %e : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>
94+
// CHECK: [[IDX:%.*]]:8 = scf.for [[I:%.*]] = [[LB:%.*]] to [[UB:%.*]] step [[STEP:%.*]] iter_args([[S0:%.*]] = [[C2]], [[P0:%.*]] = [[C0]], [[S1:%.*]] = [[C2]], [[P1:%.*]] = [[C1]], [[S2:%.*]] = [[C2]], [[P2:%.*]] = [[C0]], [[S3:%.*]] = [[C2]], [[P3:%.*]] = [[C1]])
95+
scf.for %i = %lb to %ub step %c1_i32 : i32{
96+
// CHECK: [[S0a:%.*]] = arith.addi [[S0]], [[C1]]
97+
// CHECK-NEXT: [[CMP:%.*]] = arith.cmpi eq, [[S0a]], [[C3]]
98+
// CHECK-NEXT: [[S0b:%.*]] = arith.select [[CMP]], [[C0]], [[S0a]]
99+
// CHECK-NEXT: [[P0a:%.*]] = arith.xori [[P0]], [[C1]]
100+
// CHECK-NEXT: [[P0b:%.*]] = arith.select [[CMP]], [[P0a]], [[P0]]
101+
// CHECK-NEXT: put.enter [[AREF0]][[[S0b]], [[P0b]]]
102+
%1:3 = nvws.aref.put.enter %aref0[%c0_i32, %c0_i32] {ttg.partition = 0 : i32} : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> -> !ttg.memdesc<64x16xf16, #shared0, #smem>, !ttg.memdesc<16x32xf16, #shared0, #smem>, !ttg.async.token
103+
"op1"(%1#0) {ttg.partition = 0 : i32}: (!ttg.memdesc<64x16xf16, #shared0, #smem>) -> ()
104+
"op2"(%1#1) {ttg.partition = 0 : i32} : (!ttg.memdesc<16x32xf16, #shared0, #smem>) -> ()
105+
// CHECK: op2
106+
// CHECK-NEXT: put.exit [[AREF0]][[[S0b]]]
107+
nvws.aref.put.exit %aref0[%c0_i32], %1#2 [#nvws.async_op<tma_load>, #nvws.async_op<none>] {ttg.partition = 0 : i32} : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>, !ttg.async.token
108+
109+
110+
// CHECK-NEXT: [[S1a:%.*]] = arith.addi [[S1]], [[C1]]
111+
// CHECK-NEXT: [[CMP:%.*]] = arith.cmpi eq, [[S1a]], [[C3]]
112+
// CHECK-NEXT: [[S1b:%.*]] = arith.select [[CMP]], [[C0]], [[S1a]]
113+
// CHECK-NEXT: [[P1a:%.*]] = arith.xori [[P1]], [[C1]]
114+
// CHECK-NEXT: [[P1b:%.*]] = arith.select [[CMP]], [[P1a]], [[P1]]
115+
// CHECK-NEXT: {{.*}}, [[TOK1:%.*]] = nvws.aref.get.enter [[AREF0]][[[S1b]], [[P1b]]] {ttg.partition = 1 : i32}
116+
%2:3 = nvws.aref.get.enter %aref0[%c0_i32, %c0_i32] {ttg.partition = 1 : i32} : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> -> !ttg.memdesc<64x16xf16, #shared0, #smem>, !ttg.memdesc<16x32xf16, #shared0, #smem>, !ttg.async.token
117+
"op3"(%2#0, %2#1) {ttg.partition = 1 : i32}: (!ttg.memdesc<64x16xf16, #shared0, #smem>, !ttg.memdesc<16x32xf16, #shared0, #smem>) -> ()
118+
// CHECK: op3
119+
// CHECK-NEXT: get.exit [[AREF0]][[[S1b]]], [[TOK1]] [#nvws.async_op<tc5mma>] {ttg.partition = 1 : i32}
120+
nvws.aref.get.exit %aref0[%c0_i32], %2#2 [#nvws.async_op<tc5mma>] {ttg.partition = 1 : i32} : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>, !ttg.async.token
121+
// CHECK: [[IDX1:%.*]]:4 = scf.if
122+
scf.if %cond {
123+
// CHECK-NEXT: [[S2a:%.*]] = arith.addi [[S2]], [[C1]]
124+
// CHECK-NEXT: [[CMP:%.*]] = arith.cmpi eq, [[S2a]], [[C3]]
125+
// CHECK-NEXT: [[S2b:%.*]] = arith.select [[CMP]], [[C0]], [[S2a]]
126+
// CHECK-NEXT: [[P2a:%.*]] = arith.xori [[P2]], [[C1]]
127+
// CHECK-NEXT: [[P2b:%.*]] = arith.select [[CMP]], [[P2a]], [[P2]]
128+
// CHECK-NEXT: {{.*}}, [[TOK2:%.*]] = nvws.aref.put.enter [[AREF1]][[[S2b]], [[P2b]]] {ttg.partition = 0 : i32}
129+
// CHECK-NEXT: op4
130+
// CHECK-NEXT: put.exit [[AREF1]][[[S2b]]]
131+
%4:3 = nvws.aref.put.enter %aref1[%c0_i32, %c0_i32] {ttg.partition = 0 : i32} : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> -> !ttg.memdesc<64x16xf16, #shared0, #smem>, !ttg.memdesc<16x32xf16, #shared0, #smem>, !ttg.async.token
132+
"op4"(%4#0, %4#1) {ttg.partition = 0 : i32} : (!ttg.memdesc<64x16xf16, #shared0, #smem>, !ttg.memdesc<16x32xf16, #shared0, #smem>) -> ()
133+
nvws.aref.put.exit %aref1[%c0_i32], %4#2 [#nvws.async_op<none>] {ttg.partition = 0 : i32} : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>, !ttg.async.token
134+
// CHECK-NEXT: [[S3a:%.*]] = arith.addi [[S3]], [[C1]]
135+
// CHECK-NEXT: [[CMP:%.*]] = arith.cmpi eq, [[S3a]], [[C3]]
136+
// CHECK-NEXT: [[S3b:%.*]] = arith.select [[CMP]], [[C0]], [[S3a]]
137+
// CHECK-NEXT: [[P3a:%.*]] = arith.xori [[P3]], [[C1]]
138+
// CHECK-NEXT: [[P3b:%.*]] = arith.select [[CMP]], [[P3a]], [[P3]]
139+
// CHECK-NEXT: {{.*}}, [[TOK3:%.*]] = nvws.aref.get.enter [[AREF1]][[[S3b]], [[P3b]]] {ttg.partition = 1 : i32}
140+
// CHECK-NEXT: op5
141+
// CHECK-NEXT: get.exit [[AREF1]][[[S3b]]], [[TOK3]] [#nvws.async_op<tc5mma>] {ttg.partition = 1 : i32}
142+
%5:3 = nvws.aref.get.enter %aref1[%c0_i32, %c0_i32] {ttg.partition = 1 : i32} : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> -> !ttg.memdesc<64x16xf16, #shared0, #smem>, !ttg.memdesc<16x32xf16, #shared0, #smem>, !ttg.async.token
143+
"op5"(%5#0, %5#1) {ttg.partition = 1 : i32}: (!ttg.memdesc<64x16xf16, #shared0, #smem>, !ttg.memdesc<16x32xf16, #shared0, #smem>) -> ()
144+
nvws.aref.get.exit %aref1[%c0_i32], %5#2 [#nvws.async_op<tc5mma>] {ttg.partition = 1 : i32} : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>, !ttg.async.token
145+
// CHECK-NEXT: yield [[S2b]], [[P2b]], [[S3b]], [[P3b]]
146+
}
147+
// CHECK-NEXT: } else {
148+
// CHECK-NEXT: yield [[S2]], [[P2]], [[S3]], [[P3]]
149+
// CHECK-NEXT: }
150+
// CHECK: scf.yield [[S0b]], [[P0b]], [[S1b]], [[P1b]], [[IDX1]]#0, [[IDX1]]#1, [[IDX1]]#2, [[IDX1]]#3
151+
152+
} {ttg.warp_specialize.tag = 0 : i32}
153+
tt.return
154+
}
155+
}

0 commit comments

Comments
 (0)