Skip to content

Commit 89bdc37

Browse files
authored
[Blackwell] Fix barrierSlice typing bug (#8414)
`createBarrierAlloc` always returns a [num_stages x 1] memory allocation for barriers. When num_stages=1, we still need to call `triton::createSingleBufferView` to match the type assumptions for the underlying barrier.
1 parent 09f1aa4 commit 89bdc37

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -754,17 +754,19 @@ void createBarrierAndWaitOps(scf::ForOp forOp, CoarseSchedule &schedule,
754754
Value barrierAlloc = createBarrierAlloc(forOp, numStages);
755755
Value vTrue = builder.create<arith::ConstantIntOp>(1, 1);
756756
Value phase = forOp.getRegionIterArg(phaseArgIdx);
757-
Value barrierIdx = forOp.getRegionIterArg(barrierIdxArgIdx);
758757
Value zero = builder.create<arith::ConstantIntOp>(forOp.getLoc(), 0, 32);
758+
Value barrierIdx;
759+
if (numStages > 1) {
760+
barrierIdx = forOp.getRegionIterArg(barrierIdxArgIdx);
761+
} else {
762+
barrierIdx = zero;
763+
}
759764
Value one = builder.create<arith::ConstantIntOp>(forOp.getLoc(), 1, 32);
760765
Value numStagesVal =
761766
builder.create<arith::ConstantIntOp>(forOp.getLoc(), numStages, 32);
762767

763-
Value barrierSlice = barrierAlloc;
764-
if (numStages > 1) {
765-
barrierSlice =
766-
triton::createSingleBufferView(builder, barrierAlloc, barrierIdx);
767-
}
768+
Value barrierSlice =
769+
triton::createSingleBufferView(builder, barrierAlloc, barrierIdx);
768770
mma.addCompletionBarrier(barrierSlice, vTrue);
769771
mma.setIsAsync(true);
770772

0 commit comments

Comments
 (0)