Skip to content

Commit 620237e

Browse files
authored
[Warp Specialization] Tighten user critical section when acc is not multibuffered (#7509)
cc @aeng-openai
1 parent 679954d commit 620237e

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/LoadMMASpecialization.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -752,7 +752,7 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
752752
}
753753
} else {
754754
b.setInsertionPoint(domOp);
755-
if (isa<scf::IfOp>(domOp->getParentOp()))
755+
if (isa<scf::IfOp>(domOp->getParentOp()) && accIsMultiBuffered)
756756
b.setInsertionPointToStart(domOp->getBlock());
757757
Value bar = createSingleBufferView(b, node.barPrev, node.index);
758758
b.createInto<ttng::WaitBarrierOp>(*partition, nodeStageCluster, bar,
@@ -766,7 +766,7 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
766766
mmaOp.addCompletionBarrier(bar, userPred);
767767
} else {
768768
b.setInsertionPointAfter(lastOp);
769-
if (isa<scf::IfOp>(lastOp->getParentOp()))
769+
if (isa<scf::IfOp>(lastOp->getParentOp()) && accIsMultiBuffered)
770770
b.setInsertionPoint(lastOp->getBlock()->getTerminator());
771771
Value bar = createSingleBufferView(b, node.barNext, node.index);
772772
b.createInto<ttng::ArriveBarrierOp>(*partition, nodeStageCluster, bar,

test/TritonGPU/load-mma-specialization.mlir

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -734,13 +734,13 @@ tt.func @matmul_tma_acc_with_conditional_def_and_use_no_multibuf_flag(
734734

735735
// CHECK-NEXT: scf.if [[DO_EPILOGUE]]
736736
scf.if %do_epilogue {
737-
// CHECK-NEXT: ttng.wait_barrier [[ACC_READY_BUF0]], [[ACC_PHASE]] {ttg.partition = 0 : i32}
738737
// CHECK-NEXT: "some_op"()
739738
"some_op"() : () -> ()
739+
// CHECK-NEXT: ttng.wait_barrier [[ACC_READY_BUF0]], [[ACC_PHASE]] {ttg.partition = 0 : i32}
740740
// CHECK-NEXT: [[C:%.*]], [[USER_TOK:%.*]] = ttng.tmem_load [[ACC_BUF]][]
741+
// CHECK-NEXT: ttng.arrive_barrier [[ACC_EMPTY_BUF0]], 1 {ttg.partition = 0 : i32}
741742
// CHECK-NEXT: "acc_user"([[C]])
742743
"acc_user"(%c) : (tensor<128x128xf32, #acc_layout>) -> ()
743-
// CHECK-NEXT: ttng.arrive_barrier [[ACC_EMPTY_BUF0]], 1 {ttg.partition = 0 : i32}
744744
// CHECK-NEXT: }
745745
}
746746

0 commit comments

Comments
 (0)