Skip to content

Commit 34538bc

Browse files
committed
[AMD] improved subviewing for async-copy-local-to-global
1 parent 6a6fb70 commit 34538bc

File tree

1 file changed

+30
-8
lines changed

1 file changed

+30
-8
lines changed

third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -961,6 +961,14 @@ LogicalResult Pingponger::genAsyncCopySlices(OpBuilder &builder) {
961961
return resValue;
962962
};
963963

964+
auto origSubViewType = subView.getType();
965+
auto subViewDescType = ttg::MemDescType::get(
966+
slicedShape, origSubViewType.getElementType(),
967+
origSubViewType.getEncoding(), origSubViewType.getMemorySpace(),
968+
origSubViewType.getMutableMemory(),
969+
subView.getSrc().getType().getShape());
970+
Value subViewSelector = subView.getOffsets().front();
971+
964972
assert(slicedDim != -1);
965973
SmallVector<Value> newCommits;
966974
auto numReps = origShape[slicedDim] / slicedShape[slicedDim];
@@ -973,20 +981,34 @@ LogicalResult Pingponger::genAsyncCopySlices(OpBuilder &builder) {
973981
auto extractedMask = extract(slicedMaskType, newMask, offsetAttr);
974982
auto extractedOther = extract(slicedOtherType, newOther, offsetAttr);
975983

984+
SmallVector<Value> newSubviewOffset = {subViewSelector};
985+
llvm::for_each(offset, [&](auto off) {
986+
newSubviewOffset.push_back(
987+
builder.create<arith::ConstantIntOp>(subView.getLoc(), off, 32));
988+
});
989+
990+
auto newSlicedSubView = builder.create<ttg::MemDescSubviewOp>(
991+
subView.getLoc(), subViewDescType, subView.getSrc(),
992+
newSubviewOffset);
993+
976994
auto newAsyncCopy = builder.create<ttg::AsyncCopyGlobalToLocalOp>(
977-
asyncCopy->getLoc(), extractedSrc, Value{subViews[rep].getResult()},
978-
extractedMask, extractedOther, asyncCopy.getCache(),
979-
asyncCopy.getEvict(), asyncCopy.getIsVolatile());
995+
asyncCopy->getLoc(), extractedSrc,
996+
Value{newSlicedSubView.getResult()}, extractedMask, extractedOther,
997+
asyncCopy.getCache(), asyncCopy.getEvict(),
998+
asyncCopy.getIsVolatile());
980999

9811000
auto newCommit = builder.create<ttg::AsyncCommitGroupOp>(
9821001
asyncCopy->getLoc(), newAsyncCopy.getToken());
9831002

9841003
// propagate all attributes from `mem-view` to the commit token
1004+
newSlicedSubView->setAttrs(subViews[rep]->getAttrs());
9851005
newAsyncCopy->setAttrs(subViews[rep]->getAttrs());
9861006
newCommit->setAttrs(subViews[rep]->getAttrs());
9871007

9881008
newAsyncGroups[rep].push_back(newCommit);
9891009
newCommits.push_back(newCommit);
1010+
1011+
subViews[rep]->erase();
9901012
}
9911013

9921014
auto origCommitGroup = getSingleUserOf<ttg::AsyncCommitGroupOp>(asyncCopy);
@@ -1739,11 +1761,11 @@ void Pingponger::getDotPingponged() {
17391761
LDBG("failed to update forOp signature");
17401762
}
17411763

1742-
if (llvm::succeeded(updateSignature)) {
1743-
if (llvm::failed(adjustRefinedAsyncTokens(builder))) {
1744-
LDBG("failed to update forOp signature");
1745-
}
1746-
}
1764+
// if (llvm::succeeded(updateSignature)) {
1765+
// if (llvm::failed(adjustRefinedAsyncTokens(builder))) {
1766+
// LDBG("failed to update forOp signature");
1767+
// }
1768+
// }
17471769

17481770
forOp->walk([](ttg::AsyncCommitGroupOp groupOp) {
17491771
auto users = groupOp.getResult().getUsers();

0 commit comments

Comments
 (0)