Skip to content

Commit c7a9f29

Browse files
Disabling mmav5 pipelining if there are two or more dots in the loop (#6036)
MMAv5 pipelining has a problem keeping track of async loads feeding into ops at different pipeline stages. All the corner cases are difficult to maintain, and problems with this keep arising. I am disabling support for pipelining when there is more than one dot op in the loop. This means that in most kernels all the loads will feed the dot op, or some operation preceding it, resolving the issue. It also means attention kernels won't be pipelined for now, until we finish integrating the mmav5 pipelining pass into the new pipeliner infrastructure.
1 parent 2f17e8b commit c7a9f29

File tree

2 files changed

+7
-13
lines changed

2 files changed

+7
-13
lines changed

lib/Dialect/TritonGPU/Transforms/Pipeliner/TC05MMAPipeline.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,10 @@ FailureOr<scf::ForOp> preProcessLoopForTC05MMAPipelining(scf::ForOp forOp,
615615
}
616616
});
617617

618-
if (mmaOps.empty()) {
618+
// Temporarily disable mma pipelining if there are more than one mmaOp in the
619+
// loop. This is a workaround for difficult to solve scheduling issues with
620+
// loads feeding into non-0 stage ops.
621+
if (mmaOps.empty() || mmaOps.size() > 1) {
619622
return failure();
620623
}
621624

test/TritonGPU/mma-pipeline-blackwell.mlir

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -259,20 +259,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
259259
}
260260

261261
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
262-
// CHECK-LOWER-LABEL: @do_not_pipeline_second_dot
263-
// CHECK-LOWER: scf.for {{.*}}
264-
// CHECK-LOWER: ttng.tmem_store {{.*}} {triton.pipeline_stage = 0 : i32}
265-
// CHECK-LOWER: ttng.tc_gen5_mma {{.*}} {triton.pipeline_stage = 0 : i32}
266-
// CHECK-LOWER: ttng.tmem_load {{.*}} {triton.pipeline_stage = 1 : i32}
267-
// CHECK-LOWER: ttng.tmem_alloc
268-
// CHECK-LOWER-NOT: triton.pipeline_stage
269-
// CHECK-LOWER: ttng.tc_gen5_mma
270-
// CHECK-LOWER-NOT: triton.pipeline_stage
271-
// CHECK-LOWER: ttng.tmem_load
262+
// CHECK-LOWER-LABEL: @do_not_pipeline_two_dots
272263
// CHECK-LOWER-NOT: triton.pipeline_stage
273264

274-
// CHECK-LABEL: @do_not_pipeline_second_dot
275-
tt.func public @do_not_pipeline_second_dot(%A_ptr: tensor<128x128x!tt.ptr<f16>, #blocked1>, %B_ptr: tensor<128x128x!tt.ptr<f16>, #blocked1>, %acc_ptr: tensor<128x128x!tt.ptr<f32>, #blocked>, %res_ptr: tensor<128x128x!tt.ptr<f32>, #blocked>, %arg3: i32) attributes {noinline = false} {
265+
// CHECK-LABEL: @do_not_pipeline_two_dots
266+
tt.func public @do_not_pipeline_two_dots(%A_ptr: tensor<128x128x!tt.ptr<f16>, #blocked1>, %B_ptr: tensor<128x128x!tt.ptr<f16>, #blocked1>, %acc_ptr: tensor<128x128x!tt.ptr<f32>, #blocked>, %res_ptr: tensor<128x128x!tt.ptr<f32>, #blocked>, %arg3: i32) attributes {noinline = false} {
276267
%true = arith.constant true
277268
%cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
278269
%c0_i32 = arith.constant 0 : i32

0 commit comments

Comments
 (0)