
Commit 4b184cc

patch workaround by correctly setting stage/cluster attributes (#8797)
* Patches the workaround for the loop scheduler by using the stage/cluster from the previous tmem access op in the partition to set the stage/cluster for the put.exit op and, if needed, for the follow-up put.enter op.
1 parent 046ab0e commit 4b184cc
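
The gist of the change: instead of hard-coding loop.stage = 1 on the enter/exit if-ops, the aref state now remembers, per partition, the stage/cluster of the last tmem access op that owned the buffer, and the put.exit (and, when needed, the follow-up put.enter) fall back to that recorded value when the op at the insertion point carries no scheduling attributes. Below is a minimal standalone sketch of that bookkeeping; ArefState, StageCluster, recordAccess, and stageClusterForRelease are illustrative stand-ins (with std::unordered_map in place of LLVM's DenseMap), not the pass's real types.

// Minimal sketch of the per-partition stage/cluster fallback used by the patch.
// All names here are illustrative stand-ins for the pass's real data structures.
#include <cstdint>
#include <iostream>
#include <optional>
#include <unordered_map>

using PartitionId = int32_t;

struct StageCluster {
  int32_t stage = 0;
  int32_t cluster = 0;
};

struct ArefState {
  // Last stage/cluster seen for a tmem access in each partition
  // (stands in for the DenseMap<PartitionId, StageCluster> added by the commit).
  std::unordered_map<PartitionId, std::optional<StageCluster>> stageClusters;

  // Called when a tmem load/store/mma op is rewritten: remember its
  // scheduling attributes for its partition.
  void recordAccess(std::optional<PartitionId> partition,
                    std::optional<StageCluster> sc) {
    if (partition)
      stageClusters[*partition] = sc;
  }

  // Called when emitting put.exit: if the insertion point has no explicit
  // stage/cluster, reuse the one recorded for the current partition.
  std::optional<StageCluster>
  stageClusterForRelease(std::optional<PartitionId> partition) const {
    if (!partition)
      return std::nullopt; // outside the warp-specialized region
    auto it = stageClusters.find(*partition);
    if (it == stageClusters.end())
      return std::nullopt;
    return it->second;
  }
};

int main() {
  ArefState state;
  // The MMA in partition 1 carries {stage = 2, cluster = 2}, as in the new test.
  state.recordAccess(1, StageCluster{2, 2});
  // The scf.if that ends the accumulator's lifetime has no stage/cluster of its
  // own, so the put.exit falls back to the recorded value.
  if (auto sc = state.stageClusterForRelease(1))
    std::cout << "put.exit gets stage " << sc->stage << ", cluster "
              << sc->cluster << "\n";
  return 0;
}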

File tree

2 files changed: +78 −14 lines changed


test/NVWS/aref-tmem-insertion.mlir

Lines changed: 54 additions & 0 deletions
@@ -788,3 +788,57 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} {
     tt.return
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+#blocked3 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
+#smem = #ttg.shared_memory
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, colStride = 1>
+module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} {
+  // CHECK-LABEL: @if_split_workaround
+  tt.func @if_split_workaround(%arg0: !tt.tensordesc<tensor<1x64xf16, #shared>>, %arg1: tensor<64x128x!tt.ptr<f16>, #blocked3> {tt.contiguity = dense<[1, 64]> : tensor<2xi32>, tt.divisibility = dense<16> : tensor<2xi32>}) {
+    %c0_i32 = arith.constant 0 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %true = arith.constant true
+    %false = arith.constant false
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
+    %c32_i32 = arith.constant 32 : i32
+    %result, %token = ttng.tmem_alloc : () -> (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token)
+    %0 = ttng.tmem_store %cst, %result[%token], %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+    // CHECK: scf.for
+    %1:3 = scf.for %arg2 = %c0_i32 to %c32_i32 step %c1_i32 iter_args(%arg3 = %true, %arg4 = %arg1, %arg5 = %0) -> (i1, tensor<64x128x!tt.ptr<f16>, #blocked3>, !ttg.async.token) : i32 {
+      %2:3 = "get_offsets"(%arg2) {loop.cluster = 3 : i32, loop.stage = 0 : i32, ttg.partition = array<i32: 1, 2>} : (i32) -> (i32, tensor<64x128xi32, #blocked3>, i32)
+      %3 = tt.splat %2#0 {loop.cluster = 3 : i32, loop.stage = 0 : i32, ttg.partition = array<i32: 2>} : i32 -> tensor<128xi32, #blocked2>
+      %4 = tt.descriptor_gather %arg0[%3, %2#2] {loop.cluster = 3 : i32, loop.stage = 0 : i32, ttg.partition = array<i32: 2>} : (!tt.tensordesc<tensor<1x64xf16, #shared>>, tensor<128xi32, #blocked2>, i32) -> tensor<128x64xf16, #blocked1>
+      %5 = tt.addptr %arg4, %2#1 {loop.cluster = 3 : i32, loop.stage = 1 : i32, tt.constancy = dense<1> : tensor<2xi32>, tt.contiguity = dense<[1, 64]> : tensor<2xi32>, tt.divisibility = dense<16> : tensor<2xi32>, ttg.partition = array<i32: 1>} : tensor<64x128x!tt.ptr<f16>, #blocked3>, tensor<64x128xi32, #blocked3>
+      %6 = tt.load %5 {loop.cluster = 3 : i32, loop.stage = 1 : i32, ttg.partition = array<i32: 1>} : tensor<64x128x!tt.ptr<f16>, #blocked3>
+      %7 = ttg.local_alloc %4 {loop.cluster = 2 : i32, loop.stage = 2 : i32, ttg.partition = array<i32: 2>} : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem>
+      %8 = ttg.local_alloc %6 {loop.cluster = 2 : i32, loop.stage = 2 : i32, ttg.partition = array<i32: 1>} : (tensor<64x128xf16, #blocked3>) -> !ttg.memdesc<64x128xf16, #shared, #smem>
+      // CHECK: tc_gen5_mma {{.*}} {loop.cluster = 2 : i32, loop.stage = 2 : i32
+      %9 = ttng.tc_gen5_mma %7, %8, %result[%arg5], %arg3, %true {loop.cluster = 2 : i32, loop.stage = 2 : i32, tt.self_latency = 1 : i32, ttg.partition = array<i32: 1>} : !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+      %10 = arith.cmpi eq, %arg2, %c0_i32 {loop.cluster = 1 : i32, loop.stage = 3 : i32, ttg.partition = array<i32: 0, 1>} : i32
+      %11 = arith.select %10, %false, %true {loop.cluster = 1 : i32, loop.stage = 3 : i32, ttg.partition = array<i32: 1>} : i1
+      // CHECK: scf.if
+      // CHECK-NEXT: put.exit {{.*}} {loop.cluster = 2 : i32, loop.stage = 2 : i32
+      // CHECK: } {loop.cluster = 2 : i32, loop.stage = 2 : i32
+      // CHECK: scf.if
+      // CHECK: } {loop.cluster = 4 : i32, loop.stage = 3 : i32
+      // CHECK: scf.if
+      // CHECK-NEXT: put.enter {{.*}} {loop.cluster = 2 : i32, loop.stage = 2 : i32
+      // CHECK: } {loop.cluster = 2 : i32, loop.stage = 2 : i32
+      %12 = scf.if %10 -> (!ttg.async.token) {
+        %result_0, %token_1 = ttng.tmem_load %result[%9] {ttg.partition = array<i32: 0>} : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
+        "acc_user"(%result_0) {ttg.partition = array<i32: 0>} : (tensor<128x128xf32, #blocked>) -> ()
+        scf.yield {ttg.partition = array<i32: 0, 1>} %token_1 : !ttg.async.token
+      } else {
+        scf.yield {ttg.partition = array<i32: 0, 1>} %9 : !ttg.async.token
+      } {loop.cluster = 4 : i32, loop.stage = 3 : i32, ttg.partition = array<i32: 0, 1>, ttg.partition.outputs = [array<i32: 1>]}
+      scf.yield {ttg.partition = array<i32: 0, 1, 2>} %11, %5, %12 : i1, tensor<64x128x!tt.ptr<f16>, #blocked3>, !ttg.async.token
+    } {tt.disallow_acc_multi_buffer, tt.num_stages = 3 : i32, tt.scheduled_max_stage = 3 : i32, tt.warp_specialize, ttg.partition = array<i32: 0, 1, 2>, ttg.partition.outputs = [array<i32: 1>, array<i32: 1>, array<i32: 1>], ttg.partition.stages = [0 : i32, 1 : i32, 0 : i32], ttg.warp_specialize.tag = 2 : i32}
+    tt.return
+  }
+}

third_party/nvidia/lib/Dialect/NVWS/Transforms/InsertTmemAref.cpp

Lines changed: 24 additions & 14 deletions
@@ -402,10 +402,15 @@ struct TMEMAref {
       token = op.getToken();
     }
     partitionId = paritionIdStageCluster.first;
+    if (partitionId)
+      stageClusters[*partitionId] = paritionIdStageCluster.second;
     buffer = {};
   }
-  void release(OpBuilder &b, Location loc, StageCluster stageCluster) {
+  void release(OpBuilder &b, Location loc) {
     assert(asyncOp);
+    StageCluster stageCluster;
+    if (partitionId)
+      stageCluster = stageClusters[*partitionId];
     if (kind == PUT) {
       createInto<ArefPutExitOp>(
           b, loc, {partitionId, stageCluster}, aref, token,
@@ -447,6 +452,7 @@ struct TMEMAref {
   Kind kind;
   std::optional<PartitionId> partitionId;
   std::optional<AsyncOp> asyncOp;
+  DenseMap<PartitionId, StageCluster> stageClusters;
 };
 
 TmemAccessDag::Node *
@@ -458,25 +464,18 @@ insertTmemArefImpl(TmemAccessDag::Node *node,
   if (curPartitionId && node->partitionId != curPartitionId) {
     OpBuilder b(node->op);
     Operation *prevOp = nullptr;
-    StageCluster prevStageCluster;
     if (node->parent) {
       // release right after the last op which owns the tmem
       prevOp = node->parent->op;
       b.setInsertionPointAfter(prevOp);
-      prevStageCluster = getStageCluster(prevOp);
     } else {
       // if we are inside if-stmt or for-stmt subdag and need to change
       // ownerhip, release at the top of the block
       // the parentDag op would be if-stmt or for-stmt
       prevOp = node->parentDag->op;
       b.setInsertionPointToStart(node->op->getBlock());
     }
-    if (!node->partitionId) {
-      // if node->partitionId is not set, it means we are outside ws-region
-      // reset prevPartitionId and prevStageCluster to defaults
-      prevStageCluster = {};
-    }
-    state.release(b, prevOp->getLoc(), prevStageCluster);
+    state.release(b, prevOp->getLoc());
 
     // acquire right before op that acquires ownership of tmem
     auto curOp = node->op;
@@ -489,6 +488,10 @@ insertTmemArefImpl(TmemAccessDag::Node *node,
       curOp = node->parentDag->op;
     }
     auto stageCluster = getStageCluster(curOp);
+    // if stage-cluster is empty, use the stage-cluster used from the last op
+    // that acquired ownership of tmem in a partition
+    if (!stageCluster && partitionId)
+      stageCluster = state.stageClusters[*partitionId];
     state.acquire(b, curOp->getLoc(), {partitionId, stageCluster});
   }
 
@@ -519,16 +522,22 @@ insertTmemArefImpl(TmemAccessDag::Node *node,
 
   OpBuilder b(node->op);
   if (auto tmemLoadOp = dyn_cast<TMEMLoadOp>(node->op)) {
+    if (auto id = node->partitionId)
+      state.stageClusters[*id] = getStageCluster(node->op);
     tmemLoadOp.getSrcMutable().assign(
         state.getBuffer(b, node->partitionId, node->op));
     tmemLoadOp.getDepMutable().clear();
    tmemLoadOp.getToken().replaceAllUsesWith(state.replToken);
   } else if (auto tmemStoreOp = dyn_cast<TMEMStoreOp>(node->op)) {
+    if (auto id = node->partitionId)
+      state.stageClusters[*id] = getStageCluster(node->op);
     tmemStoreOp.getDstMutable().assign(
         state.getBuffer(b, node->partitionId, node->op));
     tmemStoreOp.getDepMutable().clear();
     tmemStoreOp.getToken().replaceAllUsesWith(state.replToken);
   } else if (auto mmaOp = dyn_cast<MMAv5OpInterface>(node->op)) {
+    if (auto id = node->partitionId)
+      state.stageClusters[*id] = getStageCluster(node->op);
     if (mmaOp.getAccumulator() == state.origBuffer) {
       mmaOp.getAccDepMutable().clear();
       mmaOp.getToken().replaceAllUsesWith(state.replToken);
@@ -640,10 +649,11 @@ LogicalResult insertTmemAref(TmemAccessDag &accessDag) {
     // aref is used outside ws-loop, find the last point in the same block as
     // create op to have matching exit
     auto op1 = arefOp->getBlock()->findAncestorOpInBlock(*node->op);
+    if (auto id = node->partitionId)
+      state.stageClusters[*id] = {};
     b.setInsertionPointAfter(op1);
   }
-  stageCluster = getStageCluster(node->op);
-  state.release(b, node->op->getLoc(), stageCluster);
+  state.release(b, node->op->getLoc());
 
   if (state.kind == TMEMAref::GET) {
     // When the state ends up in a GET operation, we need to acquire and release
@@ -661,7 +671,7 @@ LogicalResult insertTmemAref(TmemAccessDag &accessDag) {
       }
     }
     state.acquire(b, node->op->getLoc(), {otherPartitionId, {}});
-    state.release(b, node->op->getLoc(), {});
+    state.release(b, node->op->getLoc());
   }
 
   return success();
@@ -751,8 +761,8 @@ void workaroundForLoopScheduler(triton::FuncOp funcOp) {
   // patch loop.stage=1
   enterIf->setAttrs(ifOp->getAttrs());
   exitIf->setAttrs(ifOp->getAttrs());
-  enterIf->setAttr(kLoopStageAttrName, b.getI32IntegerAttr(1));
-  exitIf->setAttr(kLoopStageAttrName, b.getI32IntegerAttr(1));
+  assignStage(b, enterIf, getStageCluster(putEnterOp));
+  assignStage(b, exitIf, getStageCluster(putExitOp));
 
   SetVector<int> enterExitIds, middleIds;
   enterExitIds.insert(1);
