Skip to content

Commit 04159ed

Browse files
[PIPELINER] Adding a marker for loop scheduling serialization (#6037)
Mark loops whose schedule data has been serialized with an attribute, so that retrieving the maximum scheduling stage is both faster and more reliable.
1 parent c7a9f29 commit 04159ed

File tree

5 files changed

+32
-27
lines changed

5 files changed

+32
-27
lines changed

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ static const char *kDisallowAccMultiBufferAttrName =
1515
"tt.disallow_acc_multi_buffer";
1616
static const char *kLoopStageAttrName = "loop.stage";
1717
static const char *kLoopClusterAttrName = "loop.cluster";
18+
static const char *kScheduledMaxStageAttrName = "tt.scheduled_max_stage";
1819
static const char *kLatencyAttrName = "tt.latency";
1920

2021
bool loopHasDistGreaterThanOne(scf::ForOp forOp);

lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -172,12 +172,11 @@ static std::pair<int, int> getMinMaxCluster(scf::ForOp &forOp) {
172172

173173
static std::optional<int> tryGetMaxStage(scf::ForOp &forOp) {
174174
std::optional<int> maxStage = std::nullopt;
175-
for (auto &op : forOp.getBody()->without_terminator()) {
176-
if (!op.hasAttr(mlir::triton::kLoopStageAttrName) ||
177-
!op.hasAttr(mlir::triton::kLoopClusterAttrName))
178-
continue;
179-
auto [stage, _] = getStageCluster(&op);
180-
maxStage = maxStage ? (stage > *maxStage ? stage : *maxStage) : stage;
175+
if (forOp->hasAttr(mlir::triton::kScheduledMaxStageAttrName)) {
176+
return forOp
177+
->getAttrOfType<IntegerAttr>(mlir::triton::kScheduledMaxStageAttrName)
178+
.getValue()
179+
.getSExtValue();
181180
}
182181
return maxStage;
183182
}
@@ -187,6 +186,9 @@ void tt::CoarseSchedule::serialize(scf::ForOp &forOp) {
187186
for (auto [op, stage, cluster] : getOpsInOrder(forOp)) {
188187
setStageCluster(op, stage, *cluster);
189188
}
189+
forOp->setAttr(mlir::triton::kScheduledMaxStageAttrName,
190+
IntegerAttr::get(IntegerType::get(forOp.getContext(), 32),
191+
numStages - 1));
190192
}
191193

192194
// Create a CoarseSchedule based on forOp's <stage, cluster>.

lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ static void removeAttributes(ModuleOp moduleOp) {
7272
moduleOp->walk([&](Operation *op) {
7373
op->removeAttr(mlir::triton::kLoopStageAttrName);
7474
op->removeAttr(mlir::triton::kLoopClusterAttrName);
75+
op->removeAttr(mlir::triton::kScheduledMaxStageAttrName);
7576
});
7677
}
7778

test/TritonGPU/pipeline-lower-loop.mlir

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ tt.func @one_dep_async(%lb : index, %ub : index, %step : index,
5151
scf.for %iv = %lb to %ub step %step : index {
5252
%a = tt.load %a_ptr_init {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x32x!tt.ptr<f16>, #A>
5353
"use"(%a) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #A>) -> ()
54-
}
54+
} {tt.scheduled_max_stage = 2 : i32}
5555
tt.return
5656
}
5757
}
@@ -75,7 +75,7 @@ tt.func @different_use_stages(%lb : index, %ub : index, %step : index,
7575
%a = tt.load %a_ptr_init {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x32x!tt.ptr<f16>, #A>
7676
"use1"(%a) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #A>) -> ()
7777
"use2"(%a) {loop.cluster = 0 : i32, loop.stage = 3 : i32} : (tensor<128x32xf16, #A>) -> ()
78-
}
78+
} {tt.scheduled_max_stage = 3 : i32}
7979
tt.return
8080
}
8181
}
@@ -106,7 +106,7 @@ tt.func @used_by_if_yield(%lb : index, %ub : index, %step : index,
106106
scf.yield %init_a : tensor<128x32xf16, #A>
107107
} {loop.cluster = 0 : i32, loop.stage = 2 : i32}
108108
"use"(%a_if) {loop.cluster = 0 : i32, loop.stage = 3 : i32} : (tensor<128x32xf16, #A>) -> ()
109-
}
109+
} {tt.scheduled_max_stage = 3 : i32}
110110
tt.return
111111
}
112112
}
@@ -124,7 +124,7 @@ tt.func @dist1_load(%lb : index, %ub : index, %step : index,
124124
%a = tt.load %a_ptr_init {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x32x!tt.ptr<f16>, #A>
125125
"use"(%a) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #A>) -> ()
126126
scf.yield %a : tensor<128x32xf16, #A>
127-
}
127+
} {tt.scheduled_max_stage = 2 : i32}
128128
tt.return
129129
}
130130
}
@@ -142,7 +142,7 @@ tt.func @one_dep_sync(%lb : index, %ub : index, %step : index,
142142
scf.for %iv = %lb to %ub step %step : index {
143143
%a = tt.load %a_ptr_init {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<1x!tt.ptr<f16>, #A>
144144
"use"(%a) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<1xf16, #A>) -> ()
145-
}
145+
} {tt.scheduled_max_stage = 2 : i32}
146146
tt.return
147147
}
148148
}
@@ -183,7 +183,7 @@ tt.func @one_dep_local_alloc(%lb : index, %ub : index, %step : index,
183183
%a_alloc = ttg.local_alloc %a {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #A>) -> !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable>
184184
%a_load = ttg.local_load %a_alloc {loop.cluster = 0 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable> -> tensor<128x32xf16, #A>
185185
"use"(%a_load) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #A>) -> ()
186-
}
186+
} {tt.scheduled_max_stage = 2 : i32}
187187
tt.return
188188
}
189189
}
@@ -214,7 +214,7 @@ tt.func @one_load_group(%lb : index, %ub : index, %step : index,
214214
%b = tt.load %a_ptr_init {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x32x!tt.ptr<f32>, #A>
215215
"use1"(%a){loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf32, #A>) -> ()
216216
"use2"(%b){loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf32, #A>) -> ()
217-
}
217+
} {tt.scheduled_max_stage = 2 : i32}
218218
tt.return
219219
}
220220
}
@@ -255,7 +255,7 @@ tt.func @two_load_groups(%lb : index, %ub : index, %step : index,
255255
"use1"(%a){loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf32, #A>) -> ()
256256
"use2"(%b){loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf32, #A>) -> ()
257257
"use3"(%c){loop.cluster = 0 : i32, loop.stage = 3 : i32} : (tensor<128x32xf32, #A>) -> ()
258-
}
258+
} {tt.scheduled_max_stage = 3 : i32}
259259
tt.return
260260
}
261261
}
@@ -304,7 +304,7 @@ tt.func @dependent_loads(%lb : index, %ub : index, %step : index,
304304
%b = "pointerize"(%a) {loop.cluster = 2 : i32, loop.stage = 2 : i32} : (tensor<128x32xf32, #A>) -> tensor<128x32x!tt.ptr<f32>, #A>
305305
%c = tt.load %b {loop.cluster = 2 : i32, loop.stage = 2 : i32} : tensor<128x32x!tt.ptr<f32>, #A>
306306
"use1"(%c){loop.cluster = 0 : i32, loop.stage = 4 : i32} : (tensor<128x32xf32, #A>) -> ()
307-
}
307+
} {tt.scheduled_max_stage = 4 : i32}
308308
tt.return
309309
}
310310
}
@@ -361,7 +361,7 @@ tt.func @dependent_loads_asymmetric(%lb : index, %ub : index, %step : index,
361361
%b = "pointerize"(%a) {loop.cluster = 2 : i32, loop.stage = 2 : i32} : (tensor<128x32xf32, #A>) -> tensor<128x32x!tt.ptr<f32>, #A>
362362
%c = tt.load %b {loop.cluster = 2 : i32, loop.stage = 2 : i32} : tensor<128x32x!tt.ptr<f32>, #A>
363363
"use1"(%c){loop.cluster = 0 : i32, loop.stage = 5 : i32} : (tensor<128x32xf32, #A>) -> ()
364-
}
364+
} {tt.scheduled_max_stage = 5 : i32}
365365
tt.return
366366
}
367367
}
@@ -379,7 +379,7 @@ tt.func @unused_load(%lb : index, %ub : index, %step : index,
379379
// CHECK: dummy
380380
%a = tt.load %a_ptr_init {loop.cluster = 0 : i32, loop.stage = 1 : i32} : tensor<128x32x!tt.ptr<f32>, #A>
381381
"dummy"() : () -> ()
382-
}
382+
} {tt.scheduled_max_stage = 1 : i32}
383383
tt.return
384384
}
385385
}
@@ -434,7 +434,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
434434
%B_sh = ttg.local_alloc %B {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>
435435
%acc_res = ttng.warp_group_dot %A_sh, %B_sh, %acc {loop.cluster = 0 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> -> tensor<128x128xf32, #mma>
436436
scf.yield %acc_res : tensor<128x128xf32, #mma>
437-
}
437+
} {tt.scheduled_max_stage = 2 : i32}
438438
%res_f16 = arith.truncf %res : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma>
439439
tt.return %res_f16 : tensor<128x128xf16, #mma>
440440
}
@@ -489,7 +489,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
489489
%B_sh = ttg.local_alloc %B {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared1, #ttg.shared_memory>
490490
%acc_res = ttng.warp_group_dot %A_sh, %B_sh, %acc {loop.cluster = 0 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<128x128xf16, #shared1, #ttg.shared_memory> -> tensor<128x128xf32, #mma>
491491
scf.yield %acc_res : tensor<128x128xf32, #mma>
492-
}
492+
} {tt.scheduled_max_stage = 2 : i32}
493493
tt.return %res : tensor<128x128xf32, #mma>
494494
}
495495
}
@@ -555,7 +555,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
555555
%B_sh = ttg.local_alloc %B {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>
556556
%acc_res = ttng.warp_group_dot %A_sh, %B_sh, %acc {loop.cluster = 0 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> -> tensor<128x128xf32, #mma>
557557
scf.yield %acc_res : tensor<128x128xf32, #mma>
558-
}
558+
} {tt.scheduled_max_stage = 2 : i32}
559559
%res_f16 = arith.truncf %res : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma>
560560
tt.return %res_f16 : tensor<128x128xf16, #mma>
561561
}
@@ -614,7 +614,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
614614
ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm, %true, %true {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory>, i1, i1) -> ()
615615
%acc_res = ttng.tmem_load %acc_tm {loop.cluster = 0 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory> -> tensor<128x128xf32, #blocked>
616616
scf.yield %acc_res : tensor<128x128xf32, #blocked>
617-
}
617+
} {tt.scheduled_max_stage = 2 : i32}
618618
%res_f16 = arith.truncf %res : tensor<128x128xf32, #blocked> to tensor<128x128xf16, #blocked>
619619
tt.return %res_f16 : tensor<128x128xf16, #blocked>
620620
}
@@ -669,7 +669,7 @@ tt.func @tma_load_lowering(%lb : index, %ub : index, %step : index,
669669
scf.for %iv = %lb to %ub step %step : index {
670670
%a = tt.experimental_descriptor_load %desc[%offs, %offs] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc<tensor<128x32xf16>> -> tensor<128x32xf16, #A>
671671
"use"(%a) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #A>) -> ()
672-
}
672+
} {tt.scheduled_max_stage = 2 : i32}
673673
tt.return
674674
}
675675
}
@@ -725,7 +725,7 @@ tt.func @tma_gather_lowering(%lb : index, %ub : index, %step : index,
725725
scf.for %iv = %lb to %ub step %step : index {
726726
%a = tt.experimental_descriptor_gather %desc[%x, %y] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : (!tt.tensordesc<tensor<1x128xf32>>, tensor<32xi32, #offsets>, i32) -> tensor<32x128xf32, #A>
727727
"use"(%a) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<32x128xf32, #A>) -> ()
728-
}
728+
} {tt.scheduled_max_stage = 2 : i32}
729729
tt.return
730730
}
731731
}
@@ -760,7 +760,7 @@ tt.func @tma_reuse_barrier(%lb : index, %ub : index, %step : index,
760760
"use2"(%b) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #A>) -> ()
761761
%c = tt.experimental_descriptor_load %descC[%offs, %offs] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc<tensor<128x32xf16>> -> tensor<128x32xf16, #A>
762762
"use3"(%c) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #A>) -> ()
763-
}
763+
} {tt.scheduled_max_stage = 2 : i32}
764764
tt.return
765765
}
766766
}
@@ -798,7 +798,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
798798
%B_sh = ttg.local_alloc %B {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>
799799
%acc_res = ttng.warp_group_dot %A_sh, %B_sh, %acc {loop.cluster = 0 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> -> tensor<128x128xf32, #mma>
800800
scf.yield %acc_res : tensor<128x128xf32, #mma>
801-
}
801+
} {tt.scheduled_max_stage = 2 : i32}
802802
%res_f16 = arith.truncf %res : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma>
803803
tt.return %res_f16 : tensor<128x128xf16, #mma>
804804
}
@@ -833,7 +833,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
833833
scf.for %iv = %lb to %ub step %step : index {
834834
%desc = tt.make_tensor_descriptor %A, [%shape_x, %shape_y], [%strides_x, %strides_y] {loop.cluster = 0 : i32, loop.stage = 1 : i32} : <f16>, <tensor<128x128xf16>>
835835
"use"(%desc) {loop.cluster = 0 : i32, loop.stage = 1 : i32} : (!tt.tensordesc<tensor<128x128xf16>>) -> ()
836-
}
836+
} {tt.scheduled_max_stage = 1 : i32}
837837
tt.return
838838
}
839839
}
@@ -879,7 +879,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
879879
ttng.tc_gen5_mma_scaled %A_sh, %B_sh, %acc_tm, %A_sc_sh, %B_sc_sh, %true, %true lhs = e5m2 rhs = e5m2 {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (!ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory>, !ttg.memdesc<1x2x32x4x4xi8, #shared1, #smem>, !ttg.memdesc<1x2x32x4x4xi8, #shared1, #smem>, i1, i1) -> ()
880880
%acc_res = ttng.tmem_load %acc_tm {loop.cluster = 0 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory> -> tensor<128x128xf32, #blocked>
881881
scf.yield %acc_res : tensor<128x128xf32, #blocked>
882-
}
882+
} {tt.scheduled_max_stage = 2 : i32}
883883
tt.return %res : tensor<128x128xf32, #blocked>
884884
}
885885
}

test/TritonGPU/pipeline-schedule-loop.mlir

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ tt.func @one_dep(%lb : index, %ub : index, %step : index,
2121
%res = arith.addf %acc, %a : tensor<128x32xf16, #A>
2222
scf.yield %res : tensor<128x32xf16, #A>
2323
}
24+
// CHECK: tt.scheduled_max_stage
2425
tt.return %loop#0 : tensor<128x32xf16, #A>
2526
}
2627

0 commit comments

Comments
 (0)