Skip to content

Commit 0946015

Browse files
authored
[AMD] Count vmcnt instructions for AsyncWait (#6426)
Adds `UpdateAsyncWaitCountPass` to adjust the wait counts of `AsyncWait` ops to reflect the number of interleaved direct-to-LDS assembly instructions. The LLVM backend cannot infer the dependency between the `AsyncCopies` and the `local_reads`, so we emit it from Triton, where we have the dependency information via tracing the `AsyncToken`. The pass ignores global/buffer loads because the actual number of assembly instructions is determined by the LLVM backend. Note that an underestimation only affects performance, not correctness. `findMinPathCountInDefChain` is in a separate file because we might reuse it for combining `AsyncWaits` in the `StreamPipeliner`.
1 parent e79e08e commit 0946015

File tree

11 files changed

+670
-1
lines changed

11 files changed

+670
-1
lines changed

bin/RegisterTritonDialects.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
7373
mlir::registerTritonAMDGPUConvertToBufferOps();
7474
mlir::registerTritonAMDGPUInThreadTranspose();
7575
mlir::registerTritonAMDGPUCoalesceAsyncCopy();
76+
mlir::registerTritonAMDGPUUpdateAsyncWaitCount();
7677
mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints();
7778
mlir::triton::registerTritonAMDGPULowerInstructionSchedHints();
7879
mlir::registerTritonAMDFoldTrueCmpI();

test/TritonGPU/amd/amd-update-async-wait-count.mlir

Lines changed: 371 additions & 0 deletions
Large diffs are not rendered by default.

third_party/amd/backend/compiler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,8 @@ def make_ttgir(mod, metadata, options):
277277
passes.common.add_canonicalizer(pm)
278278
passes.common.add_cse(pm)
279279
passes.common.add_symbol_dce(pm)
280+
if use_async_copy:
281+
amd.passes.ttgpuir.add_update_async_wait_count(pm, options.arch)
280282
pm.run(mod)
281283
return mod
282284

third_party/amd/include/TritonAMDGPUTransforms/Passes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ std::unique_ptr<Pass> createTritonAMDGPUInThreadTransposePass();
4141
std::unique_ptr<Pass>
4242
createTritonAMDGPUCoalesceAsyncCopyPass(std::string archGenName = {});
4343

44+
std::unique_ptr<Pass>
45+
createTritonAMDGPUUpdateAsyncWaitCountPass(std::string archGenName = {});
46+
4447
std::unique_ptr<Pass> createTritonAMDGPUFoldTrueCmpIPass();
4548

4649
/// Generate the code for registering passes.

third_party/amd/include/TritonAMDGPUTransforms/Passes.td

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,26 @@ def TritonAMDGPUCoalesceAsyncCopy: Pass<"tritonamdgpu-coalesce-async-copy", "mli
248248
];
249249
}
250250

251+
def TritonAMDGPUUpdateAsyncWaitCount: Pass<"tritonamdgpu-update-async-wait-count", "mlir::ModuleOp"> {
252+
let summary = "Adjust async wait count to allow prefetching over multiple loop iterations";
253+
254+
let description = [{
255+
GFX9:
256+
LLVM cannot see the dependency across loop iterations between AsyncCopy and local_reads. So we
257+
compute the number of interleaving global memory instructions to emit the correct waitcnt during lowering.
258+
}];
259+
260+
let constructor = "mlir::createTritonAMDGPUUpdateAsyncWaitCountPass()";
261+
262+
let dependentDialects = [];
263+
264+
let options = [
265+
Option<"archGenerationName", "arch-generation-name",
266+
"std::string", /*default=*/"std::string{}",
267+
"GFX generation name of target device.">,
268+
];
269+
}
270+
251271
def TritonAMDFoldTrueCmpI: Pass<"tritonamdgpu-fold-true-cmpi", "mlir::ModuleOp"> {
252272
let summary = "Fold true arith.cmpi to %true";
253273

third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ add_triton_library(TritonAMDGPUTransforms
1111
MfmaGroup.cpp
1212
InThreadTranspose.cpp
1313
FoldTrueCmpIOp.cpp
14+
UpdateAsyncWaitCount.cpp
15+
Utility.cpp
1416

1517
DEPENDS
1618
TritonAMDGPUIR

third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,11 +346,13 @@ bool StreamPipeliner::createAsyncCopy(tt::LoadOp loadOp, Value alloc,
346346
builder.create<ttg::AsyncCommitGroupOp>(loc, newLoadOp->getResult(0));
347347
ttg::AsyncWaitOp wait =
348348
builder.create<ttg::AsyncWaitOp>(loc, commit->getResult(0), 0);
349-
350349
// We need to place the prefetches (AsyncCopy) after the AsyncWaits which
351350
// create a barrier to ensure all warps are finished reading the shared buffer
352351
// we will write into. This is done by scheduling it as a local_store.
353352
scheduleOp(newLoadOp, SCHED_LOCAL_STORE);
353+
// Place ttg.async_commit_group op next to async load so the later
354+
// UpdateAsyncWaitCount pass can deduce better waitcnts
355+
scheduleOp(commit, SCHED_LOCAL_STORE);
354356

355357
// Create local load which consumes the async token from the AsyncWait
356358
auto sharedLoad =
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
#include "amd/lib/TritonAMDGPUToLLVM/Utility.h"
2+
#include "amd/lib/TritonAMDGPUTransforms/Utility.h"
3+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
4+
#include "llvm/ADT/TypeSwitch.h"
5+
6+
#define GEN_PASS_CLASSES
7+
#include "TritonAMDGPUTransforms/Passes.h"
8+
9+
// This pass updates the waitCount of `AsyncWait` Ops to represent the number of
10+
// inflight async load operation between the async_wait and the definition of
11+
// the AsyncToken, thus allowing to wait only on the dependent async loads
12+
// allowing loads issued after to complete in the future.
13+
// This also means we should never overestimate the value to ensure
14+
// correctness; being conservative and underestimating is fine given that only
15+
// affects performance
16+
// For each async_wait we need to compute the minimum across all AsyncToken
17+
// operands.
18+
// For each token the minimum number of async transactions along its
19+
// def chain is deduced. A token can be copied when passing in as loop initial
20+
// argument and yielded from a loop body in which case we need to take the
21+
// minimum along both paths.
22+
// We do not exit early if we encounter another async_wait along the def chain
23+
// because the pipeliner will merge redundant waits for us already
24+
25+
using namespace mlir;
26+
namespace tt = triton;
27+
namespace ttg = triton::gpu;
28+
29+
// Returns the number of individual async load memory transactions when copy
30+
// data from the given |srcTy| in global memory to the given |dstTy| in shared
31+
// memory.
32+
int getNumberOfLoadInstructions(RankedTensorType srcTy,
33+
ttg::MemDescType dstTy) {
34+
auto shape = srcTy.getShape();
35+
LinearLayout srcLayout = tt::gpu::toLinearLayout(shape, srcTy.getEncoding());
36+
LinearLayout sharedLayout =
37+
tt::gpu::toLinearLayout(shape, dstTy.getEncoding());
38+
LinearLayout srcToSharedLayout = srcLayout.invertAndCompose(sharedLayout);
39+
40+
// On GFX9 we cannot split direct to lds loads into multiple ones because we
41+
// need coalesced writes. So we can divide the number of registers by the
42+
// contiguity to get the number of load instructions.
43+
int contig = srcToSharedLayout.getNumConsecutiveInOut();
44+
int numberOfRegisters = srcToSharedLayout.getInDimSize(
45+
StringAttr::get(srcTy.getContext(), "register"));
46+
int loadInstructionCount = std::max(1, numberOfRegisters / contig);
47+
return loadInstructionCount;
48+
}
49+
50+
// The pipeliner always insert ops following an order of ttg.async_load ->
51+
// [token] -> ttg.async_commit_group -> [token] -> ttg.async_wait. So here we
52+
// scan the operands of ttg.async_commit_group to count the number of issued
53+
// async load intrinsics.
54+
int getNumberOfLoadInstructions(Operation *op) {
55+
if (isa<ttg::AsyncCommitGroupOp>(op)) {
56+
int count = 0;
57+
for (auto token : op->getOperands()) {
58+
auto defOp = token.getDefiningOp();
59+
if (!defOp)
60+
continue;
61+
if (auto copyOp = llvm::dyn_cast<ttg::AsyncCopyGlobalToLocalOp>(defOp)) {
62+
count += getNumberOfLoadInstructions(copyOp.getSrc().getType(),
63+
copyOp.getResult().getType());
64+
} else if (auto copyOp =
65+
llvm::dyn_cast<amdgpu::BufferLoadToLocalOp>(defOp)) {
66+
auto srcTy = cast<RankedTensorType>(LLVM::AMD::getPointerTypeWithShape(
67+
copyOp.getPtr(), copyOp.getOffsets()));
68+
count += getNumberOfLoadInstructions(srcTy, copyOp.getDest().getType());
69+
}
70+
}
71+
return count;
72+
}
73+
if (isa<tt::LoadOp, tt::StoreOp, amdgpu::BufferLoadToLocalOp,
74+
amdgpu::BufferStoreOp, tt::AtomicRMWOp, tt::AtomicCASOp,
75+
amdgpu::BufferAtomicRMWOp>(op)) {
76+
op->emitRemark("Global memory operation between async wait and "
77+
"async_loads. This will hinder the interleaving of memory "
78+
"operations and might impact performance.");
79+
}
80+
return 0;
81+
}
82+
83+
// LLVM cannot infer the dependency between direct to lds (async) loads and
84+
// the local reads between warps in a workgroup. As a workaround we update the
85+
// waitcnt to represent the number of hardware instructions we are
86+
// interleaving with. This allows us to manually emit the waitcnt during
87+
// lowering.
88+
void updateWaitCount(ttg::AsyncWaitOp waitOp, RewriterBase &rewriter) {
89+
int waitCnt = std::numeric_limits<int>::max();
90+
91+
// AsyncWait can await multiple tokens so we get the minimum from all
92+
// tokens
93+
for (auto token : waitOp.getOperands()) {
94+
// Traverse def chain from waitOp to the producer of the token and count
95+
// the minumum number of vmcnt instructions
96+
auto tokenWaitCnt =
97+
deduceMinCountOnDefChain(token, waitOp, [](Operation *op) {
98+
return getNumberOfLoadInstructions(op);
99+
});
100+
waitCnt = std::min(waitCnt, tokenWaitCnt);
101+
}
102+
103+
if (waitCnt == std::numeric_limits<int>::max() || waitOp.getNum() == waitCnt)
104+
return;
105+
106+
rewriter.modifyOpInPlace(waitOp, [&]() { waitOp.setNum(waitCnt); });
107+
}
108+
109+
struct TritonAMDGPUUpdateAsyncWaitCountPass
110+
: public TritonAMDGPUUpdateAsyncWaitCountBase<
111+
TritonAMDGPUUpdateAsyncWaitCountPass> {
112+
TritonAMDGPUUpdateAsyncWaitCountPass(StringRef archGenName) {
113+
this->archGenerationName = archGenName.str();
114+
}
115+
116+
void runOnOperation() override {
117+
tt::AMD::TargetInfo targetInfo(archGenerationName);
118+
if (!targetInfo.isCDNA()) {
119+
return;
120+
}
121+
122+
ModuleOp m = getOperation();
123+
124+
SmallVector<ttg::AsyncWaitOp> waitOps;
125+
getOperation()->walk(
126+
[&](ttg::AsyncWaitOp waitOp) { waitOps.push_back(waitOp); });
127+
128+
for (auto waitOp : waitOps) {
129+
IRRewriter builder(waitOp->getContext());
130+
updateWaitCount(waitOp, builder);
131+
}
132+
}
133+
};
134+
135+
std::unique_ptr<Pass>
136+
mlir::createTritonAMDGPUUpdateAsyncWaitCountPass(std::string archGenName) {
137+
return std::make_unique<TritonAMDGPUUpdateAsyncWaitCountPass>(archGenName);
138+
}
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
#include "Utility.h"

#include "mlir/Dialect/SCF/IR/SCF.h"

#include <limits>

namespace deduceMin {
int deduceMinCountInBlock(Block &block,
                          const std::function<int(Operation *)> &countFunc);

// Accumulates countFunc(op) for the ops in [beginOp, endOp) and returns the
// minimum over all control-flow paths. Note that endOp itself is NOT counted.
int deduceMinCountBetweenOps(Operation *beginOp, Operation *endOp,
                             const std::function<int(Operation *)> &countFunc) {
  assert(beginOp && endOp);
  assert(beginOp == endOp || beginOp->isBeforeInBlock(endOp));
  int count = 0;
  for (auto op = beginOp; op != endOp; op = op->getNextNode()) {
    if (auto ifOp = llvm::dyn_cast<scf::IfOp>(op)) {
      assert(!ifOp.getThenRegion().empty() && !ifOp.getElseRegion().empty());
      auto minThen =
          deduceMinCountInBlock(ifOp.getThenRegion().front(), countFunc);
      auto minElse =
          deduceMinCountInBlock(ifOp.getElseRegion().front(), countFunc);
      // Either branch may execute, so only the smaller contribution is
      // guaranteed.
      count += std::min(minThen, minElse);
    } else if (auto forOp = llvm::dyn_cast<scf::ForOp>(op)) {
      // A loop only contributes when its trip count is a known constant;
      // otherwise conservatively assume zero iterations.
      auto tripCount = constantTripCount(forOp.getLowerBound(),
                                         forOp.getUpperBound(), forOp.getStep())
                           .value_or(0);
      if (tripCount > 0) {
        count += tripCount * deduceMinCountInBlock(*forOp.getBody(), countFunc);
      }
    } else {
      count += countFunc(op);
    }
  }
  return count;
}

// Returns the minimum found when accumulating countFunc(op) for all paths
// between the block's first op and its last op (the last op is excluded by
// deduceMinCountBetweenOps; for non-empty blocks that is the terminator).
int deduceMinCountInBlock(Block &block,
                          const std::function<int(Operation *)> &countFunc) {
  if (block.empty())
    return 0;
  return deduceMinCountBetweenOps(&block.front(), &block.back(), countFunc);
}
} // namespace deduceMin

// Recursive worker: accumulates counts along the def chain of defValue up to
// consumerOp. pathSum holds the count accumulated on the current path so far;
// foundMin is the smallest complete path seen so far.
int deduceMinCountOnDefChain(Value defValue, Operation *consumerOp,
                             const std::function<int(Operation *)> &countFunc,
                             int pathSum, int foundMin) {
  using namespace deduceMin;
  // If the value is not defined in the same region as the consumer we need to
  // peel the parent region of consumer until we arrive at value's region
  while (consumerOp->getParentRegion() != defValue.getParentRegion()) {
    pathSum += deduceMinCountBetweenOps(&consumerOp->getBlock()->front(),
                                        consumerOp, countFunc);
    consumerOp = consumerOp->getParentOp();
  }

  // Base case: we reached the producer. Close the path with the ops between
  // producer and consumer.
  if (Operation *defOp = defValue.getDefiningOp()) {
    pathSum +=
        deduceMinCountBetweenOps(defOp->getNextNode(), consumerOp, countFunc);
    foundMin = std::min(foundMin, pathSum);
    return foundMin;
  }
  // If value is a loop carried argument (BlockArgument) we need to look at
  // initial arguments of the loop and the previous iteration
  if (auto arg = mlir::dyn_cast<BlockArgument>(defValue)) {
    Block *block = arg.getOwner();
    auto forOp = dyn_cast<scf::ForOp>(block->getParentOp());

    // Failed to track, return 0 conservatively.
    if (!forOp || forOp.getBody()->empty()) {
      return 0;
    }
    // The induction variable (block argument 0) cannot carry an async token;
    // bail out conservatively instead of indexing getInitArgs()/yield
    // operands with argNumber - 1 == -1 below.
    if (arg.getArgNumber() == 0) {
      return 0;
    }

    Operation *firstOpInLoop = &*forOp.getBody()->begin();
    pathSum += deduceMinCountBetweenOps(firstOpInLoop, consumerOp, countFunc);

    // Prune: this path already exceeds the best known minimum.
    if (pathSum >= foundMin)
      return foundMin;

    // Path 1: the value flowing in from outside the loop (init args).
    Value incomingVal = forOp.getInitArgs()[arg.getArgNumber() - 1];
    int countLoopInit = deduceMinCountOnDefChain(incomingVal, forOp, countFunc,
                                                 pathSum, foundMin);

    // Path 2: the value yielded by the previous loop iteration.
    Operation *yieldOp = block->getTerminator();
    Value prevVal = yieldOp->getOperand(arg.getArgNumber() - 1);
    int countPreviousIter = deduceMinCountOnDefChain(
        prevVal, yieldOp, countFunc, pathSum, foundMin);

    return std::min(std::min(countLoopInit, countPreviousIter), foundMin);
  }

  // Unsupported value, return 0 conservatively.
  return 0;
}

int deduceMinCountOnDefChain(Value defValue, Operation *consumerOp,
                             llvm::function_ref<int(Operation *)> countFunc) {
  return deduceMinCountOnDefChain(defValue, consumerOp, countFunc, 0,
                                  std::numeric_limits<int>::max());
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#ifndef TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTRANSFORMS_UTILITY_H_
#define TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTRANSFORMS_UTILITY_H_

#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"

// NOTE(review): a using-directive in a header leaks into every includer;
// consider qualifying with mlir:: instead — confirm no includer relies on it.
using namespace mlir;

// DFS over the def chain of 'defValue' starting from 'consumerOp'. Returns
// the minimum found when accumulating countFunc(op) for all non-control-flow
// ops between the value's producer and the consumer. Traverses through for
// loops: both the init-args path into the loop and the loop-carried
// (previous-iteration) path are explored, taking the minimum of the two.
// countFunc(Operation*) returns the value to accumulate for an operation.
// Returns 0 conservatively if the def chain cannot be traversed.
int deduceMinCountOnDefChain(Value defValue, Operation *consumerOp,
                             llvm::function_ref<int(Operation *)> countFunc);

#endif // TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTRANSFORMS_UTILITY_H_

0 commit comments

Comments
 (0)