Commit 4c6349d

[WS] fix assign-stage-phase to propagate to for-op control opnds (#8634)
- fixes assign-stage-phase to propagate the default partition to a for-op's control operands when needed, not just to the for-op itself
- tracks the root partition in aref-tmem-insertion, removing the ad-hoc skipping of tmem insertion
1 parent: bae3b79

4 files changed: +81 −16 lines changed
test/NVWS/aref-tmem-insertion.mlir

Lines changed: 6 additions & 4 deletions

@@ -558,15 +558,17 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} {
     // CHECK-NEXT: aref.create
     // CHECK-NEXT: aref.put.enter
     %result, %token = ttng.tmem_alloc : () -> (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token)
-    scf.for %arg3 = %c0_i32 to %arg0 step %c1_i32 : i32 {
-      %0 = ttg.local_alloc %arg1 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem>
+    %5 = scf.for %arg3 = %c0_i32 to %arg0 step %c1_i32 iter_args(%arg4 = %token) -> (!ttg.async.token) : i32 {
+      %0 = ttg.local_alloc %arg1 {ttg.partition = array<i32: 0>} : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem>
       %1 = tt.descriptor_load %arg2[%arg3, %arg3] {ttg.partition = array<i32: 2>} : !tt.tensordesc<tensor<64x128xf16, #shared>> -> tensor<64x128xf16, #blocked1>
       %2 = arith.addf %1, %1 {ttg.partition = array<i32: 0>} : tensor<64x128xf16, #blocked1>
       %3 = ttg.local_alloc %2 {ttg.partition = array<i32: 0>} : (tensor<64x128xf16, #blocked1>) -> !ttg.memdesc<64x128xf16, #shared, #smem>
       // CHECK: aref.buffer
-      %4 = ttng.tc_gen5_mma %0, %3, %result[%token], %true, %true {ttg.partition = array<i32: 1>} : !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
-    } {tt.num_stages = 2 : i32, tt.warp_specialize, ttg.partition.stages = [0 : i32, 1 : i32, 0 : i32], ttg.warp_specialize.tag = 18 : i32}
+      %4 = ttng.tc_gen5_mma %0, %3, %result[%arg4], %true, %true {ttg.partition = array<i32: 1>} : !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+      scf.yield %4 : !ttg.async.token
+    } {ttg.partition = array<i32: 0, 1, 2>, ttg.partition.outputs = [array<i32: 1>], tt.num_stages = 2 : i32, tt.warp_specialize, ttg.partition.stages = [0 : i32, 1 : i32, 0 : i32], ttg.warp_specialize.tag = 18 : i32}
     // CHECK: aref.put.exit
+    ttng.tmem_load %result[%5] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
     tt.return
   }
test/NVWS/assign_stage_phase.mlir

Lines changed: 52 additions & 0 deletions

@@ -674,3 +674,55 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
     tt.return
   }
 }
+
+// -----
+#blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
+#smem = #ttg.shared_memory
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, colStride = 1>
+module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} {
+  // CHECK-LABEL: @for_loop_control_operand_ppg
+  tt.func @for_loop_control_operand_ppg(%lb: i32, %ub: i32, %step: i32, %ptr0: !tt.ptr<i32>) {
+    %true = arith.constant true
+    %arefBuf = ttng.tmem_alloc : () -> !ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+    %aref = nvws.aref.create %arefBuf : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]>
+    %_0, %tok = nvws.aref.put.enter %aref : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token
+    // CHECK: put.enter
+    // CHECK-NEXT: [[RET:%.*]]:5 = scf.for
+    %tok0 = scf.for %iv0 = %lb to %ub step %step iter_args(%tok1 = %tok) -> (!ttg.async.token) : i32 {
+      // CHECK-NEXT: tt.addptr {{.*}} {ttg.partition = array<i32: 0, 1, 2>}
+      // CHECK-NEXT: tt.load {{.*}} {ttg.partition = array<i32: 0, 1, 2>}
+      // CHECK-NEXT: "lb1"({{.*}}) {ttg.partition = array<i32: 0, 1, 2>}
+      // CHECK-NEXT: "step1"({{.*}}) {ttg.partition = array<i32: 0, 1, 2>}
+      %ptrub = tt.addptr %ptr0, %iv0 {ttg.partition = array<i32: 1, 2>} : !tt.ptr<i32>, i32
+      %ub1 = tt.load %ptrub {ttg.partition = array<i32: 1, 2>} : !tt.ptr<i32>
+      %lb1 = "lb1"(%iv0) {ttg.partition = array<i32: 1, 2>} : (i32) -> i32
+      %step1 = "step1"(%iv0) {ttg.partition = array<i32: 1, 2>} : (i32) -> i32
+      // CHECK-NEXT: [[RET1:%.*]]:3 = scf.for
+      %tok5 = scf.for %iv = %lb1 to %ub1 step %step1 iter_args(%tok2 = %tok1) -> (!ttg.async.token) : i32 {
+        %sA = "load1"(%iv) {ttg.partition = array<i32: 1>} : (i32) -> !ttg.memdesc<128x64xf32, #shared, #smem>
+        %sB = "load2"(%iv) {ttg.partition = array<i32: 1>} : (i32) -> !ttg.memdesc<64x128xf32, #shared, #smem>
+        %buf = nvws.aref.buffer %aref, %tok2 {ttg.partition = array<i32: 2>} : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+        ttng.tc_gen5_mma %sA, %sB, %buf, %true, %true {ttg.partition = array<i32: 2>} : !ttg.memdesc<128x64xf32, #shared, #smem>, !ttg.memdesc<64x128xf32, #shared, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+        scf.yield {ttg.partition = array<i32: 1, 2>} %tok2 : !ttg.async.token
+      } {ttg.partition = array<i32: 1, 2>, ttg.partition.outputs = [array<i32: 2>]}
+      // CHECK: scf.yield
+      // CHECK-NEXT: {ttg.partition = array<i32: 0, 1, 2>, ttg.partition.outputs = [array<i32: 2>, array<i32: 0, 2>, array<i32: 2>]}
+      // CHECK-NEXT: nvws.aref.put.exit {{.*}}[[[RET1]]#1]
+      nvws.aref.put.exit %aref, %tok5 [#nvws.async_op<tc5mma>] {ttg.partition = array<i32: 2>} : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token
+      %_1, %token_2 = nvws.aref.get.enter %aref {ttg.partition = array<i32: 1>} : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token
+      nvws.aref.get.exit %aref, %token_2 [#nvws.async_op<none>] {ttg.partition = array<i32: 1>} : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token
+      %buf1, %tok6 = nvws.aref.put.enter %aref {ttg.partition = array<i32: 2>} : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token
+      // CHECK: aref.put.enter
+      // CHECK-NEXT: scf.yield
+      scf.yield {ttg.partition = array<i32: 1, 2>} %tok6 : !ttg.async.token
+      // CHECK-NEXT: {tt.warp_specialize, ttg.partition = array<i32: 0, 1, 2>, ttg.partition.outputs = [array<i32: 2>, array<i32: 0, 2>, array<i32: 2>, array<i32: 0, 1>, array<i32: 0, 1>]}
+    } {tt.warp_specialize, ttg.partition = array<i32: 1, 2>, ttg.partition.outputs = [array<i32: 2>]}
+    // CHECK-NEXT: aref.put.exit {{.*}}[[[RET]]#1]
+    nvws.aref.put.exit %aref, %tok0 [#nvws.async_op<tc5mma>] : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token
+    %_2, %token_2 = nvws.aref.get.enter %aref : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token
+    nvws.aref.get.exit %aref, %token_2 [#nvws.async_op<none>] : <[!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>]>, !ttg.async.token
+    tt.return
+  }
+}

third_party/nvidia/lib/Dialect/NVWS/Transforms/AssignStagePhase.cpp

Lines changed: 5 additions & 0 deletions

@@ -465,6 +465,11 @@ void visitBackwardSlice(scf::ForOp wsLoop, Value value,
       visitBackwardSlice(wsLoop,
                          forOp.getBody()->getTerminator()->getOperand(*pos),
                          callback, visited);
+      // visit control operands of for-op
+      for (int idx = 0; idx < forOp.getNumControlOperands(); ++idx) {
+        auto control = forOp.getOperand(idx);
+        visitBackwardSlice(wsLoop, control, callback, visited);
+      }
     }
   } else if (wsLoop.getBody()->findAncestorOpInBlock(*defOp)) {
     callback(defOp);

third_party/nvidia/lib/Dialect/NVWS/Transforms/InsertTmemAref.cpp

Lines changed: 18 additions & 12 deletions

@@ -246,23 +246,27 @@ struct TmemAccessDag {
     return accessDag;
   }

-  std::set<PartitionId> collectPartitions(Node *node) {
+  std::pair<bool, std::set<PartitionId>> collectPartitions(Node *node) {
     std::set<PartitionId> partitions;
+    bool hasRootPartition = false;
     if (node->partitionId)
       partitions.insert(*node->partitionId);

     while (node->user) {
       node = node->user.get();
       if (node->partitionId)
         partitions.insert(*node->partitionId);
+      else
+        hasRootPartition = true;
       for (auto &subDag : node->subDags) {
         if (subDag) {
-          auto ps = collectPartitions(subDag.get());
+          auto [rootPartition, ps] = collectPartitions(subDag.get());
+          hasRootPartition = hasRootPartition || rootPartition;
           partitions.insert(ps.begin(), ps.end());
         }
       }
     }
-    return partitions;
+    return {hasRootPartition, partitions};
   };

   void printNode(Node *node, int indent, llvm::raw_ostream &os) {
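To make the new contract concrete, here is a small self-contained model of the traversal (a toy node chain, not the real TmemAccessDag): users with an explicit partition id land in the set, while any untagged user flips hasRootPartition, marking the implicit root/default partition as an owner.

#include <iostream>
#include <memory>
#include <optional>
#include <set>
#include <utility>

// Toy stand-in for the access-DAG node: a linear chain of users, each
// optionally tagged with a partition id. Untagged users belong to the
// implicit root/default partition.
struct ToyNode {
  std::optional<int> partitionId;
  std::unique_ptr<ToyNode> user;
};

std::pair<bool, std::set<int>> collectPartitions(ToyNode *node) {
  std::set<int> partitions;
  bool hasRootPartition = false;
  if (node->partitionId)
    partitions.insert(*node->partitionId);
  while (node->user) {
    node = node->user.get();
    if (node->partitionId)
      partitions.insert(*node->partitionId);
    else
      hasRootPartition = true; // untagged user => root partition is an owner
  }
  return {hasRootPartition, partitions};
}

int main() {
  // Models: tmem_alloc -> tc_gen5_mma (@1) -> tmem_load (untagged, root)
  ToyNode root;
  root.user = std::make_unique<ToyNode>();
  root.user->partitionId = 1;
  root.user->user = std::make_unique<ToyNode>();
  auto [hasRoot, parts] = collectPartitions(&root);
  std::cout << hasRoot << " " << parts.size() << "\n"; // prints: 1 1
}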
@@ -273,20 +277,23 @@ struct TmemAccessDag {
     }
     std::set<PartitionId> partitions;
     os << "|- [" << node->op << "]";
+    bool hasRootPartition = false;
     if (node->partitionId)
       partitions.insert(*node->partitionId);
+    else
+      hasRootPartition = true;
     if (node->op) {
       os << node->op->getName().getStringRef() << " ";
       if (auto tmemAlloc = dyn_cast<TMEMAllocOp>(node->op)) {
         if (tmemAlloc.getSrc()) {
           os << " %src ";
         } else {
-          partitions = collectPartitions(node);
+          std::tie(hasRootPartition, partitions) = collectPartitions(node);
         }
       }
       os << " ";
     }
-    os << "[";
+    os << "[" << (hasRootPartition ? "root" : "") << "]";
     for (auto partition : partitions) {
       os << " @" << partition << " ";
     }
@@ -526,10 +533,6 @@ LogicalResult insertTmemAref(TmemAccessDag &accessDag) {
   auto rootNode = accessDag.getRootNode();
   auto allocOp = cast<TMEMAllocOp>(rootNode->op);

-  // do nothing for alloc with src, whose user is in the same partition
-  if (allocOp.getSrc() && rootNode->user->partitionId == rootNode->partitionId)
-    return success();
-
   std::optional<bool> isMultiStaged;
   for (auto user : allocOp.getResult().getUsers()) {
     if (auto mmaOp = dyn_cast<MMAv5OpInterface>(user)) {
@@ -613,7 +616,8 @@ LogicalResult insertTmemAref(TmemAccessDag &accessDag) {
   // the corresponding partition to prevent deadlocks. This is necessary
   // because if we're inside an outer loop, re-entering the loop without
   // posting a matching GET operation for the PUT would cause the dead-lock.
-  auto partitions = accessDag.collectPartitions(accessDag.getRootNode());
+  auto [hasRootPartition, partitions] =
+      accessDag.collectPartitions(accessDag.getRootNode());
   std::optional<int> otherPartitionId;
   // since we only have two partition, we just pick the other partition for
   // get
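The deadlock the comment describes can be pictured with a single-slot producer/consumer handshake. A loose analogy in plain C++, not the NVWS runtime: a binary semaphore stands in for the aref's empty/full tracking, put.enter for acquiring the slot, get.exit for releasing it.

#include <semaphore>

// Loose analogy only: a one-buffer aref behaves like a binary semaphore on
// the "slot is free" state. put.enter waits for a free slot; the matching
// get.exit frees it again.
std::binary_semaphore slotFree{1};

void putEnter() { slotFree.acquire(); } // waits until the slot is free
void getExit() { slotFree.release(); }  // consumer done; slot free again

int main() {
  for (int i = 0; i < 2; ++i) {
    putEnter();
    // ... producer fills the buffer, consumer drains it ...
    getExit(); // omit this and the second putEnter() never returns:
               // the re-entry deadlock the pass guards against
  }
}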
@@ -747,9 +751,11 @@ LogicalResult runOnFunction(triton::FuncOp funcOp) {

   for (auto &accessDag : tmemDags) {
     LLVM_DEBUG({ accessDag.printDag(llvm::dbgs()); });
-    auto partitions = accessDag.collectPartitions(accessDag.getRootNode());
+    auto [hasRootPartition, partitions] =
+        accessDag.collectPartitions(accessDag.getRootNode());
     assert(partitions.size() <= 2 && "expecting at most 2 partitions");
-    if (!partitions.empty())
+    auto totalOwners = hasRootPartition + partitions.size();
+    if (totalOwners > 1)
       if (failed(insertTmemAref(accessDag)))
         return failure();
   }

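The gating change is worth spelling out: previously any non-empty partition set triggered aref insertion, while the removed alloc-with-src early exit above patched over the single-owner case ad hoc. Now the implicit root partition counts as one owner alongside the explicit partitions, and insertion happens only when ownership is shared. A small sketch of the decision (needsAref is a hypothetical helper, not from the patch):

#include <cstddef>
#include <set>

// Sketch of the new gate: the implicit root/default partition counts as one
// owner alongside the explicitly assigned partitions; an aref is only needed
// once more than one owner touches the allocation.
bool needsAref(bool hasRootPartition, const std::set<int> &partitions) {
  // bool promotes to 0 or 1, matching `hasRootPartition + partitions.size()`.
  std::size_t totalOwners = hasRootPartition + partitions.size();
  return totalOwners > 1;
}

int main() {
  // Single explicit owner: no aref (the old !partitions.empty() check
  // would have inserted one here).
  bool a = needsAref(false, {1});
  // Explicit partition plus root-partition accesses: aref needed.
  bool b = needsAref(true, {1});
  // Root-partition accesses only: still a single owner, no aref.
  bool c = needsAref(true, {});
  return (a == false && b == true && c == false) ? 0 : 1;
}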