Skip to content

Commit d82cfd3

Browse files
authored
[Blackwell] Handle control flow in TMEM allocation (#7698)
TODO: write unit test
1 parent d8774e3 commit d82cfd3

File tree

2 files changed

+133
-21
lines changed

2 files changed

+133
-21
lines changed

lib/Dialect/TritonNvidiaGPU/Transforms/TensorMemoryAllocation.cpp

Lines changed: 90 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
#include "mlir/Analysis/Liveness.h"
2+
#include "mlir/Dialect/Arith/IR/Arith.h"
3+
#include "mlir/Interfaces/ControlFlowInterfaces.h"
24
#include "mlir/Support/LogicalResult.h"
35
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
46
#include "mlir/Transforms/Passes.h"
57
#include "triton/Analysis/Allocation.h"
68
#include "triton/Dialect/Triton/IR/Utility.h"
9+
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
10+
#include "triton/Dialect/TritonGPU/IR/Traits.h"
711
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
812
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
913
#include "llvm/ADT/EquivalenceClasses.h"
@@ -175,30 +179,92 @@ static TMemChunk allocFirstFit(MemoryBitMap &memoryMap,
175179
return chunk;
176180
}
177181

178-
static Operation *getAlloc(Value value) {
179-
while (true) {
180-
if (auto allocOp = value.getDefiningOp<TMEMAllocOp>())
181-
return allocOp;
182-
if (auto indexOp = value.getDefiningOp<ttg::MemDescIndexOp>()) {
183-
value = indexOp.getSrc();
182+
static SmallVector<Operation *> getAlloc(Value value) {
183+
SmallVector<Operation *> allocs;
184+
DenseSet<Value> seen;
185+
SmallVector<Value> worklist{value};
186+
187+
while (!worklist.empty()) {
188+
Value v = worklist.pop_back_val();
189+
if (!seen.insert(v).second)
184190
continue;
185-
}
186-
if (auto reinterpOp = value.getDefiningOp<ttg::MemDescReinterpretOp>()) {
187-
value = reinterpOp.getSrc();
191+
192+
// Handle block arguments.
193+
if (auto arg = dyn_cast<BlockArgument>(v)) {
194+
Block *block = arg.getOwner();
195+
Operation *parentOp = block->getParentOp();
196+
197+
// Handle block with predecessors.
198+
if (!block->isEntryBlock()) {
199+
for (Block *pred : block->getPredecessors()) {
200+
Operation *predOp = pred->getTerminator();
201+
auto br = dyn_cast<BranchOpInterface>(predOp);
202+
if (!br) {
203+
llvm::report_fatal_error("unhandled branch op: " +
204+
predOp->getName().getStringRef());
205+
}
206+
SmallVector<Attribute> operands(br->getNumOperands());
207+
auto it = llvm::find(br->getSuccessors(), block);
208+
unsigned idx = std::distance(br->getSuccessors().begin(), it);
209+
SuccessorOperands args = br.getSuccessorOperands(idx);
210+
Value operand =
211+
args.getForwardedOperands()[arg.getArgNumber() -
212+
args.getProducedOperandCount()];
213+
worklist.push_back(operand);
214+
}
215+
continue;
216+
}
217+
218+
// Handle region entry arguments.
219+
if (auto wsOp = dyn_cast<ttg::WarpSpecializePartitionsOp>(parentOp)) {
220+
worklist.push_back(
221+
wsOp.getParentOp().getExplicitCaptures()[arg.getArgNumber()]);
222+
} else if (auto forOp = dyn_cast<scf::ForOp>(parentOp)) {
223+
unsigned idx = arg.getArgNumber() - 1;
224+
worklist.push_back(forOp.getYieldedValues()[idx]);
225+
worklist.push_back(forOp.getInits()[idx]);
226+
} else if (auto whileOp = dyn_cast<scf::WhileOp>(parentOp)) {
227+
unsigned idx = arg.getArgNumber();
228+
if (arg.getParentRegion() == &whileOp.getAfter()) {
229+
worklist.push_back(whileOp.getConditionOp().getArgs()[idx]);
230+
} else {
231+
worklist.push_back(whileOp.getYieldedValues()[idx]);
232+
worklist.push_back(whileOp.getInits()[idx]);
233+
}
234+
} else {
235+
llvm::report_fatal_error(
236+
"unhandled parent op when looking for TMEM alloc: " +
237+
parentOp->getName().getStringRef());
238+
}
188239
continue;
189240
}
190-
if (auto slice = value.getDefiningOp<TMEMSubSliceOp>()) {
191-
value = slice.getSrc();
192-
continue;
241+
242+
Operation *defOp = v.getDefiningOp();
243+
unsigned idx = cast<OpResult>(v).getResultNumber();
244+
if (isa<TMEMAllocOp>(defOp)) {
245+
allocs.push_back(defOp);
246+
} else if (defOp->hasTrait<OpTrait::MemDescViewTrait>()) {
247+
worklist.push_back(defOp->getOperand(0));
248+
} else if (auto sliceOp = dyn_cast<TMEMSubSliceOp>(defOp)) {
249+
worklist.push_back(sliceOp.getSrc());
250+
} else if (auto selectOp = dyn_cast<arith::SelectOp>(defOp)) {
251+
worklist.push_back(selectOp.getTrueValue());
252+
worklist.push_back(selectOp.getFalseValue());
253+
} else if (auto ifOp = dyn_cast<scf::IfOp>(defOp)) {
254+
worklist.push_back(ifOp.thenYield().getOperand(idx));
255+
worklist.push_back(ifOp.elseYield().getOperand(idx));
256+
} else if (auto forOp = dyn_cast<scf::ForOp>(defOp)) {
257+
worklist.push_back(forOp.getYieldedValues()[idx]);
258+
worklist.push_back(forOp.getInits()[idx]);
259+
} else if (auto whileOp = dyn_cast<scf::WhileOp>(defOp)) {
260+
worklist.push_back(whileOp.getConditionOp().getArgs()[idx]);
261+
} else {
262+
llvm::report_fatal_error("unhandled op when looking for TMEM alloc: " +
263+
defOp->getName().getStringRef());
193264
}
194-
auto arg = dyn_cast<BlockArgument>(value);
195-
if (!arg || !isa<triton::gpu::WarpSpecializePartitionsOp>(
196-
arg.getOwner()->getParentOp()))
197-
llvm::report_fatal_error("expected to find a TMEM alloc op");
198-
auto partitions = cast<triton::gpu::WarpSpecializePartitionsOp>(
199-
arg.getOwner()->getParentOp());
200-
value = partitions.getParentOp().getExplicitCaptures()[arg.getArgNumber()];
201265
}
266+
267+
return allocs;
202268
}
203269

204270
class RowIdConstraints {
@@ -245,8 +311,11 @@ allocateTMem(Operation *parentOp,
245311
if (allocSize.numRows == 64) {
246312
// HW restriction, the A alloc and accumulator needs to be in the same
247313
// rows.
248-
rowIdConstraints.joinOps(getAlloc(mmaOp.getA()),
249-
getAlloc(mmaOp.getAccumulator()));
314+
SmallVector<Operation *> lhsAllocs = getAlloc(mmaOp.getA());
315+
SmallVector<Operation *> accAllocs = getAlloc(mmaOp.getAccumulator());
316+
for (Operation *lhsAlloc : lhsAllocs)
317+
for (Operation *accAlloc : accAllocs)
318+
rowIdConstraints.joinOps(lhsAlloc, accAlloc);
250319
} else {
251320
// TODO: we need to handle cases where the format is blockM and we
252321
// have multiple blocks.

test/TritonNvidiaGPU/test_tensor_memory_allocation.mlir

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,3 +350,46 @@ tt.func @alloc_warp_specialize_explicit_capture() {
350350
}
351351

352352
}
353+
354+
// -----
355+
356+
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 8}>
357+
#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = true, elementBitWidth = 8}>
358+
#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
359+
#tmem = #ttng.tensor_memory_encoding<blockM = 64, blockN = 64, unpacked = true>
360+
#tmem_scales = #ttng.tensor_memory_scales_encoding<>
361+
362+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65536 : i32} {
363+
364+
// CHECK-LABEL: @mma_lhs_tmem
365+
tt.func @mma_lhs_tmem(
366+
%b: !ttg.memdesc<64x64xf16, #shared1, #ttg.shared_memory>,
367+
%useAcc: i1,
368+
%pred: i1,
369+
%barrier: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory>,
370+
%barrierPred: i1
371+
) {
372+
// CHECK-COUNT-4: ttng.tmem_alloc {{.*}} tensor_memory_row_offset = 0 : i32
373+
// CHECK-NOT: tensor_memory_row_offset
374+
%a0 = ttng.tmem_alloc : () -> !ttg.memdesc<64x64xf16, #tmem, #ttng.tensor_memory, mutable>
375+
%a1 = ttng.tmem_alloc : () -> !ttg.memdesc<64x64xf16, #tmem, #ttng.tensor_memory, mutable>
376+
%a2 = ttng.tmem_alloc : () -> !ttg.memdesc<64x64xf16, #tmem, #ttng.tensor_memory, mutable>
377+
%c = ttng.tmem_alloc : () -> !ttg.memdesc<64x64xf32, #tmem, #ttng.tensor_memory, mutable>
378+
379+
%a = arith.select %barrierPred, %a0, %a1 : !ttg.memdesc<64x64xf16, #tmem, #ttng.tensor_memory, mutable>
380+
381+
cf.cond_br %barrierPred, ^switch, ^bb1(%a : !ttg.memdesc<64x64xf16, #tmem, #ttng.tensor_memory, mutable>)
382+
383+
^switch:
384+
cf.br ^bb1(%a2 : !ttg.memdesc<64x64xf16, #tmem, #ttng.tensor_memory, mutable>)
385+
386+
^bb1(%lhs: !ttg.memdesc<64x64xf16, #tmem, #ttng.tensor_memory, mutable>):
387+
ttng.tc_gen5_mma %lhs, %b, %c, %useAcc, %pred, %barrier[%barrierPred] {is_async} :
388+
!ttg.memdesc<64x64xf16, #tmem, #ttng.tensor_memory, mutable>,
389+
!ttg.memdesc<64x64xf16, #shared1, #ttg.shared_memory>,
390+
!ttg.memdesc<64x64xf32, #tmem, #ttng.tensor_memory, mutable>,
391+
!ttg.memdesc<1xi64, #shared2, #ttg.shared_memory>
392+
tt.return
393+
}
394+
395+
}

0 commit comments

Comments
 (0)