Skip to content

Commit e4e6687

Browse files
[BW][PIPELINE] Add an option to tl.range to disallow accumulator multi-buffering (triton-lang#5858)
Rework mmav5 pipelining to allow pipelining of mma when multibuffering of the accumulator is impossible by putting uses in the same stage as the mma and blocking on wait until the current mma finishes. Based on this support, introducing a new flag to `tl.range` that controls whether multibuffering of the accumulator of the dots in the loop is allowed. Without the mentioned rework of mmav5 pipelining we would simply not pipeline cases where multibuffering is disallowed.
1 parent de0f754 commit e4e6687

File tree

7 files changed

+291
-15
lines changed

7 files changed

+291
-15
lines changed

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ namespace mlir {
1010
namespace triton {
1111

1212
static const char *kNumStagesAttrName = "tt.num_stages";
13+
static const char *kDisallowAccMultiBufferAttrName =
14+
"tt.disallow_acc_multi_buffer";
1315
static const char *kLoopStageAttrName = "loop.stage";
1416
static const char *kLoopClusterAttrName = "loop.cluster";
1517

@@ -37,6 +39,10 @@ void addOps(scf::ForOp forOp, int stage,
3739
void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,
3840
Value val);
3941

42+
// Return true if the given ForOp has the attribute
43+
// `tt.disallow_acc_multi_buffer` set to true.
44+
bool getDisallowAccMultiBuffer(scf::ForOp forOp);
45+
4046
// Return the minClusterId and maxClusterId for the given ForOp.
4147
std::pair<int, int> getMinMaxCluster(scf::ForOp &forOp);
4248
std::pair<int, int> getStageCluster(Operation *op);

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,12 @@ void mlir::triton::replaceUsesAndPropagateType(OpBuilder &builder,
234234
op->erase();
235235
}
236236

237+
// Return true if the given ForOp has the attribute
238+
// `tt.disallow_acc_multi_buffer` set to true.
239+
bool mlir::triton::getDisallowAccMultiBuffer(scf::ForOp forOp) {
240+
return forOp->hasAttr(mlir::triton::kDisallowAccMultiBufferAttrName);
241+
}
242+
237243
std::optional<std::pair<int, int>>
238244
mlir::triton::maybeGetStageCluster(Operation *op) {
239245
auto stage =

lib/Dialect/TritonGPU/Transforms/Pipeliner/TC05MMAPipeline.cpp

Lines changed: 67 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,17 @@ void annotateWithPipelineStage(IRRewriter &builder, Operation *op, int stage) {
3535
IntegerAttr::get(builder.getI32Type(), stage));
3636
}
3737

38+
int getPipelineStage(Operation *op) {
39+
return op->getAttrOfType<IntegerAttr>(kPipelineStageAttrName).getInt();
40+
}
41+
3842
struct MMAInfo {
3943
struct AccOverridePoint {
4044
Operation *op;
4145
Value condition = nullptr;
4246
Value initValue = nullptr;
4347
int distance = 0;
48+
bool isFlag = false;
4449
};
4550

4651
ttng::TMEMAllocOp accAlloc; // Directly precedes the dot, allocating tmem
@@ -136,6 +141,7 @@ std::optional<MMAInfo::AccOverridePoint>
136141
getAccOverridePointInLoop(scf::ForOp forOp, ttng::TMEMAllocOp accUse,
137142
ttng::TMEMLoadOp accDef) {
138143
MMAInfo::AccOverridePoint accOverridePoint;
144+
accOverridePoint.isFlag = false;
139145
DenseSet<Value> seen;
140146
Value v = accUse.getSrc();
141147
if (v == nullptr) {
@@ -219,6 +225,7 @@ getAccUseFlagFalseInLoop(scf::ForOp forOp, Value useAccFlagUse) {
219225

220226
IRRewriter builder(v.getDefiningOp()->getNextNode());
221227
MMAInfo::AccOverridePoint accOverridePoint;
228+
accOverridePoint.isFlag = true;
222229
accOverridePoint.distance = dist;
223230
Location loc = v.getDefiningOp()->getLoc();
224231
auto vTrue =
@@ -374,9 +381,12 @@ void updateAccUsesInLoop(IRRewriter &builder, scf::ForOp forOp, MMAInfo &info,
374381
}
375382
auto load = builder.create<ttng::TMEMLoadOp>(
376383
domOp->getLoc(), info.accLoad.getType(), extractSlice);
384+
// If accumulator is multi-buffered, it is implicit that we put the load
385+
// in the last stage.
386+
int pipelineStage = info.accIsMultiBuffered ? numStages - 1 : 0;
377387
annotateWithPipelineStage(
378388
builder, forOp.getBody()->findAncestorOpInBlock(*load.getOperation()),
379-
numStages - 1);
389+
pipelineStage);
380390
for (auto user : directUses) {
381391
user->replaceUsesOfWith(info.accLoad, load);
382392
}
@@ -574,12 +584,45 @@ void createBarrierAndWaitOps(IRRewriter &builder, scf::ForOp forOp,
574584
info.barrierIdx = newBarrierIdx;
575585
annotateWithPipelineStage(builder, info.barrierIdx.getDefiningOp(), 0);
576586

587+
Value originalPhase = info.phase;
577588
Value newPhase = builder.create<arith::SelectOp>(
578589
loc, info.phase.getType(), barWrap,
579590
builder.create<arith::XOrIOp>(loc, info.phase, one), info.phase);
580591
replaceAllUsesDominatedBy(newPhase.getDefiningOp(), newPhase, info.phase);
581592
info.phase = newPhase;
582593
annotateWithPipelineStage(builder, info.phase.getDefiningOp(), 0);
594+
595+
// We need to add a barrier before the load from the accumulator, if it is in the
596+
// same stage as the dot.
597+
ttng::TMEMLoadOp tmemLoad = nullptr;
598+
SmallVector<Operation *> users = {info.accAlloc->getUsers().begin(),
599+
info.accAlloc->getUsers().end()};
600+
while (!users.empty()) {
601+
auto user = users.pop_back_val();
602+
if (isa<ttg::MemDescSubviewOp>(user)) {
603+
users.append(user->getUsers().begin(), user->getUsers().end());
604+
}
605+
if (isa<ttng::TMEMLoadOp>(user) && forOp->isAncestor(user)) {
606+
if (tmemLoad) {
607+
assert(tmemLoad == cast<ttng::TMEMLoadOp>(user) &&
608+
"Should have only one tmem load from the accumulator");
609+
}
610+
tmemLoad = cast<ttng::TMEMLoadOp>(user);
611+
}
612+
}
613+
if (tmemLoad) {
614+
int loadStage =
615+
getPipelineStage(forOp.getBody()->findAncestorOpInBlock(*tmemLoad));
616+
int mmaOpStage = getPipelineStage(mmaOp);
617+
if (loadStage == mmaOpStage) {
618+
builder.setInsertionPoint(tmemLoad);
619+
auto barrier =
620+
builder.create<ttng::WaitBarrierOp>(loc, barrierSlice, originalPhase);
621+
annotateWithPipelineStage(
622+
builder, forOp.getBody()->findAncestorOpInBlock(*barrier),
623+
mmaOpStage);
624+
}
625+
}
583626
}
584627

585628
bool isSafeToPipeline(ttng::TCGen5MMAScaledOp scaledDot) {
@@ -684,17 +727,33 @@ FailureOr<scf::ForOp> preProcessLoopForTC05MMAPipelining(scf::ForOp forOp,
684727
continue;
685728
}
686729

730+
SmallVector<Operation *> accUses = getDirectAccUses(accLoad);
731+
DominanceInfo domOpInfo(forOp);
732+
Operation *newAccLoadInsertPoint =
733+
findNearestCommonDominator(accUses, domOpInfo);
687734
// Check pipelining and multi-buffering constraints
688-
// 1. If the acc is used by an op in the loop (other than the dot) it
689-
// requires multi-buffering to pipeline, as different stages cannot operate
690-
// on the same buffer.
691-
bool requiresMultiBuffer = !getDirectAccUses(accLoad).empty();
735+
// 1. Really needs multibuffering - if the acc is used unconditionally in
736+
// the loop, or under different conditions. If we cannot multibuffer in this
737+
// case, we may as well not pipeline at all, as we will have to wait after
738+
// the dot in every loop iteration.
739+
scf::IfOp topLevelIf =
740+
newAccLoadInsertPoint
741+
? dyn_cast<scf::IfOp>(forOp.getBody()->findAncestorOpInBlock(
742+
*newAccLoadInsertPoint))
743+
: nullptr;
744+
bool requiresMultiBuffer = accUses.size() > 0 && !topLevelIf;
745+
// If we override the acc in the loop, it is generally hard to handle it
746+
without multibuffering. We make an exception if it is not a physical
747+
// override of a value, but just setting a flag that acc is not used. In
748+
this case we don't need a different buffer to store the init value.
749+
requiresMultiBuffer |=
750+
accOverridePoint.has_value() && !accOverridePoint->isFlag;
692751

693752
// 2. If the acc is not overwritten in the loop (by op other than the dot),
694753
// it cannot be multi-buffered. This is because the overwrite is the only
695754
// way to initialize next buffer without incurring a copy.
696-
bool canMultiBuffer = accOverridePoint.has_value();
697-
755+
bool canMultiBuffer = accOverridePoint.has_value() &&
756+
!mlir::triton::getDisallowAccMultiBuffer(forOp);
698757
if (requiresMultiBuffer && !canMultiBuffer) {
699758
continue;
700759
}
@@ -703,7 +762,7 @@ FailureOr<scf::ForOp> preProcessLoopForTC05MMAPipelining(scf::ForOp forOp,
703762
.accLoad = accLoad,
704763
.accDef = accOverridePoint,
705764
.yieldArgNo = yieldArgNo,
706-
.accIsMultiBuffered = requiresMultiBuffer};
765+
.accIsMultiBuffered = canMultiBuffer};
707766

708767
builder.setInsertionPoint(forOp);
709768
Value zero = builder.create<arith::ConstantIntOp>(forOp.getLoc(), 0, 32);

python/test/unit/language/test_matmul.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,8 @@ def simple_persistent_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak,
147147
stride_bk, stride_bn, #
148148
stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
149149
BLOCK_SIZE_K: tl.constexpr, #
150-
GROUP_SIZE_M: tl.constexpr, NUM_SMS: tl.constexpr):
150+
GROUP_SIZE_M: tl.constexpr, NUM_SMS: tl.constexpr,
151+
DISALLOW_ACC_MULTI_BUFFER: tl.constexpr):
151152
start_pid = tl.program_id(axis=0)
152153
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
153154
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
@@ -171,7 +172,7 @@ def simple_persistent_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak,
171172

172173
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
173174

174-
for _ in range(0, k_tiles * tiles_per_SM):
175+
for _ in tl.range(0, k_tiles * tiles_per_SM, disallow_acc_multi_buffer=DISALLOW_ACC_MULTI_BUFFER):
175176
ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
176177
if ki == 0:
177178
tile_id += NUM_SMS
@@ -220,7 +221,8 @@ def simple_persistent_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak,
220221
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 16), (64, 128, 32), (32, 32, 32), (256, 128, 16),
221222
(64, 512, 16), (512, 64, 16), (64, 16, 16)])
222223
@pytest.mark.parametrize("NUM_WARPS", [4, 8])
223-
def test_simple_persistent_matmul(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, device):
224+
@pytest.mark.parametrize("DISALLOW_ACC_MULTI_BUFFER", [True, False])
225+
def test_simple_persistent_matmul(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, DISALLOW_ACC_MULTI_BUFFER, device):
224226
M, N, K = 1024, 512, 256
225227
NUM_STAGES = 3
226228
a = torch.randn(M, K, dtype=torch.float16, device=device)
@@ -238,7 +240,8 @@ def test_simple_persistent_matmul(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, device):
238240
b.stride(0), b.stride(1), #
239241
output.stride(0), output.stride(1), #
240242
BLOCK_SIZE_M=BLOCK_M, BLOCK_SIZE_N=BLOCK_N, BLOCK_SIZE_K=BLOCK_K, #
241-
GROUP_SIZE_M=8, NUM_SMS=NUM_SMS, num_stages=NUM_STAGES, num_warps=NUM_WARPS)
243+
GROUP_SIZE_M=8, NUM_SMS=NUM_SMS, DISALLOW_ACC_MULTI_BUFFER=DISALLOW_ACC_MULTI_BUFFER, num_stages=NUM_STAGES,
244+
num_warps=NUM_WARPS)
242245
ref_out = torch.matmul(a.to(torch.float32), b.to(torch.float32)).to(torch.float16)
243246

244247
torch.testing.assert_close(ref_out, output, atol=0.01, rtol=0.01)
@@ -250,8 +253,8 @@ def test_simple_persistent_matmul(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, device):
250253
if (device == "cuda" and torch.cuda.get_device_capability()[0] == 10 and BLOCK_M % 64 == 0 and BLOCK_N % 8 == 0
251254
and BLOCK_N > 16):
252255
ttgir = k.asm["ttgir"]
253-
pattern = (r"ttng.wait_barrier %arg")
254-
assert re.search(pattern, str(ttgir)), "The TTGIR does not match the expected pattern."
256+
pattern = "ttng.wait_barrier %arg"
257+
assert ttgir.count(pattern) > 0, "Expect barrier coming from the previous iteration."
255258

256259

257260
@triton.jit

python/triton/compiler/code_generator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -998,6 +998,7 @@ def visit_For(self, node):
998998
return
999999
num_stages = None
10001000
loop_unroll_factor = None
1001+
disallow_acc_multi_buffer = False
10011002
flatten = None
10021003
if IteratorClass is language.range:
10031004
iterator = IteratorClass(*iter_args, **iter_kwargs)
@@ -1009,6 +1010,7 @@ def visit_For(self, node):
10091010
step = iterator.step
10101011
num_stages = iterator.num_stages
10111012
loop_unroll_factor = iterator.loop_unroll_factor
1013+
disallow_acc_multi_buffer = iterator.disallow_acc_multi_buffer
10121014
flatten = iterator.flatten
10131015
elif IteratorClass is range:
10141016
# visit iterator arguments
@@ -1084,6 +1086,8 @@ def visit_For(self, node):
10841086
for_op.set_attr("tt.num_stages", self.builder.get_int32_attr(num_stages))
10851087
if loop_unroll_factor is not None:
10861088
for_op.set_attr("tt.loop_unroll_factor", self.builder.get_int32_attr(loop_unroll_factor))
1089+
if disallow_acc_multi_buffer:
1090+
for_op.set_attr("tt.disallow_acc_multi_buffer", self.builder.get_unit_attr())
10871091
if flatten:
10881092
for_op.set_attr("tt.flatten", self.builder.get_unit_attr())
10891093

python/triton/language/core.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2865,12 +2865,15 @@ def kernel(...):
28652865
:param loop_unroll_factor: Tells the Triton IR level loop unroller how many
28662866
times to unroll a for loop that this range is used with. Less than 2 for
28672867
this value implies no unrolling.
2868+
:param disallow_acc_multi_buffer: If true, prevent the accumulator of the dot
2869+
operation in the loop from being multi-buffered, if applicable.
28682870
:param flatten: automatically flatten the loop nest starting at this loop to
28692871
create a single flattened loop. The compiler will try to pipeline the
28702872
flattened loop which can avoid stage stalling.
28712873
"""
28722874

2873-
def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None, flatten=None):
2875+
def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None,
2876+
disallow_acc_multi_buffer=False, flatten=None):
28742877
if step is None:
28752878
self.step = constexpr(1)
28762879
else:
@@ -2883,6 +2886,7 @@ def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_fact
28832886
self.end = arg2
28842887
self.num_stages = num_stages
28852888
self.loop_unroll_factor = loop_unroll_factor
2889+
self.disallow_acc_multi_buffer = disallow_acc_multi_buffer
28862890
self.flatten = flatten
28872891

28882892
def __iter__(self):

0 commit comments

Comments
 (0)