Skip to content

Commit 20f8e72

Browse files
committed
[AMD] improved subviewing for async-copy-global-to-local
1 parent 6a6fb70 commit 20f8e72

File tree

1 file changed

+20
-1
lines changed

1 file changed

+20
-1
lines changed

third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -961,6 +961,13 @@ LogicalResult Pingponger::genAsyncCopySlices(OpBuilder &builder) {
961961
return resValue;
962962
};
963963

964+
965+
auto origSubViewType = subView.getType();
966+
auto subViewDescType = ttg::MemDescType::get(
967+
slicedShape, origSubViewType.getElementType(), origSubViewType.getEncoding(), origSubViewType.getMemorySpace(),
968+
origSubViewType.getMutableMemory(), origSubViewType.getShape());
969+
Value subViewSelector = subView.getOffsets().front();
970+
964971
assert(slicedDim != -1);
965972
SmallVector<Value> newCommits;
966973
auto numReps = origShape[slicedDim] / slicedShape[slicedDim];
@@ -973,20 +980,32 @@ LogicalResult Pingponger::genAsyncCopySlices(OpBuilder &builder) {
973980
auto extractedMask = extract(slicedMaskType, newMask, offsetAttr);
974981
auto extractedOther = extract(slicedOtherType, newOther, offsetAttr);
975982

983+
SmallVector<Value> newSubviewOffset = {subViewSelector};
984+
llvm::for_each(offset, [&](auto off){
985+
newSubviewOffset.push_back(builder.create<arith::ConstantIntOp>(
986+
subView.getLoc(), off, 32));
987+
});
988+
989+
auto newSlicedSubView = builder.create<ttg::MemDescSubviewOp>(
990+
subView.getLoc(), subViewDescType, subView.getSrc(), newSubviewOffset);
991+
976992
auto newAsyncCopy = builder.create<ttg::AsyncCopyGlobalToLocalOp>(
977-
asyncCopy->getLoc(), extractedSrc, Value{subViews[rep].getResult()},
993+
asyncCopy->getLoc(), extractedSrc, Value{newSlicedSubView.getResult()},
978994
extractedMask, extractedOther, asyncCopy.getCache(),
979995
asyncCopy.getEvict(), asyncCopy.getIsVolatile());
980996

981997
auto newCommit = builder.create<ttg::AsyncCommitGroupOp>(
982998
asyncCopy->getLoc(), newAsyncCopy.getToken());
983999

9841000
// propagate all attributes from the original sub-view to the new sliced sub-view, the async copy, and the commit group
1001+
newSlicedSubView->setAttrs(subViews[rep]->getAttrs());
9851002
newAsyncCopy->setAttrs(subViews[rep]->getAttrs());
9861003
newCommit->setAttrs(subViews[rep]->getAttrs());
9871004

9881005
newAsyncGroups[rep].push_back(newCommit);
9891006
newCommits.push_back(newCommit);
1007+
1008+
subViews[rep]->erase();
9901009
}
9911010

9921011
auto origCommitGroup = getSingleUserOf<ttg::AsyncCommitGroupOp>(asyncCopy);

0 commit comments

Comments
 (0)