Skip to content

Commit 0f1e09e

Browse files
authored
[WarpSpec] Implementation of code partitioning (#6746)
The main flow is in WSCodePartition.cpp: - Collect all communication channels between producers and consumers. - ProducerOp reordering - Dataflow multi-buffering: correctly generate bufferIdx and phase for the communication ops (WSBuffer.cpp) - Lowering loads to asynchronous (WSLowerMem.cpp) - Insert communication ops - Separate the function into partitions of warp_specialize according to attributes on ops (WSSpecialize.cpp)
1 parent 11b6747 commit 0f1e09e

File tree

10 files changed

+3426
-1
lines changed

10 files changed

+3426
-1
lines changed

test/Hopper/WarpSpecialization/ws_code_partition.mlir

Lines changed: 262 additions & 0 deletions
Large diffs are not rendered by default.

third_party/nvidia/hopper/include/Transforms/Passes.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,23 @@ def NVGPUTestWSDataPartition : Pass<"nvgpu-test-ws-data-partition", "mlir::Modul
4646
];
4747
}
4848

49+
def NVGPUTestWSCodePartition: Pass<"nvgpu-test-ws-code-partition", "mlir::ModuleOp"> {
  let summary = "test warp specialization code partition";

  // NOTE: fixed typo "baed" -> "based" in the user-visible description.
  let description = "This pass generates warp specialized code based on task id attributes.";

  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
                           "mlir::triton::TritonDialect",
                           "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
                           "mlir::triton::nvws::NVWSDialect"];
  let options = [
    Option<"numBuffers", "num-buffers",
           "int32_t", /*default*/"0",
           "number of buffering for producer-consumer">,
    Option<"numWarpGroups", "num-warp-groups",
           "int32_t", /*default*/"0",
           "number of warp groups for warp specialization">
  ];
}
67+
4968
#endif // NV_TRANSFORMS_PASSES

third_party/nvidia/hopper/lib/Transforms/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
add_triton_library(NVHopperTransforms
22
WarpSpecialization.cpp
3+
WarpSpecialization/CodePartitionUtility.cpp
4+
WarpSpecialization/WSBuffer.cpp
5+
WarpSpecialization/WSCodePartition.cpp
6+
WarpSpecialization/WSLowerMem.cpp
7+
WarpSpecialization/WSSpecialize.cpp
38
WarpSpecialization/Utility.cpp
49
WarpSpecialization/WSDataPartition.cpp
510
WarpSpecialization/WSTaskPartition.cpp
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
#include "CodePartitionUtility.h"
2+
#include "mlir/Analysis/SliceAnalysis.h"
3+
#include "mlir/Pass/Pass.h"
4+
#include "mlir/Pass/PassManager.h"
5+
#include "mlir/Transforms/Passes.h"
6+
#include "nvidia/hopper/include/Transforms/Passes.h"
7+
#include <list>
8+
#include <unordered_set>
9+
10+
namespace tt = mlir::triton;
11+
namespace ttg = mlir::triton::gpu;
12+
namespace ttng = ::mlir::triton::nvidia_gpu;
13+
namespace mlir {
14+
15+
#define DEBUG_TYPE "nvgpu-ws-utility"
16+
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
17+
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
18+
19+
// Check to see if op is enclosed under ifOp.
20+
bool enclosing(scf::IfOp ifOp, Operation *op) {
21+
return ifOp->isProperAncestor(op);
22+
}
23+
24+
// Returns true when `op` lies strictly inside the body of `forOp`,
// i.e. forOp is a proper ancestor of op.
bool enclosing(scf::ForOp forOp, Operation *op) {
  return forOp->isProperAncestor(op);
}
27+
28+
// Check to see if there is no outer loop that is enclosed under ifOp.
29+
bool immediateEnclosing(scf::IfOp ifOp, Operation *subOp) {
30+
auto pOp = subOp->getParentOfType<scf::ForOp>();
31+
if (!pOp)
32+
return true;
33+
return !enclosing(ifOp, pOp.getOperation());
34+
}
35+
36+
// Return number of AccumCnts for the given ctrlOp. We need one for each nested
37+
// region that contains a channel.
38+
unsigned getAccumCnts(Operation *ctrlOp,
39+
const DenseSet<Operation *> &regionsWithChannels) {
40+
unsigned cnt = 0;
41+
LDBG("getAccumCnts: " << ctrlOp);
42+
for (auto *op : regionsWithChannels) {
43+
LDBG("-- getAccumCnts: " << ctrlOp << " regionsWithChannels " << op);
44+
if (ctrlOp == op) {
45+
++cnt;
46+
continue;
47+
}
48+
if (auto forOp = dyn_cast<scf::ForOp>(ctrlOp)) {
49+
if (enclosing(forOp, op))
50+
++cnt;
51+
continue;
52+
}
53+
if (auto ifOp = dyn_cast<scf::IfOp>(ctrlOp)) {
54+
if (enclosing(ifOp, op))
55+
++cnt;
56+
continue;
57+
}
58+
llvm_unreachable("region op other than If/For is not supported");
59+
}
60+
return cnt;
61+
}
62+
63+
// Assume parentForOp carries an accumCnt for the specified ctrlOp. The
// argument index of ctrlOp's counter is the number of channel-carrying
// regions visited before ctrlOp in a preorder walk of parentForOp.
unsigned getAccumArgIdx(scf::ForOp parentForOp, Operation *ctrlOp,
                        const DenseSet<Operation *> &regionsWithChannels) {
  unsigned regionsSeen = 0, argIdx = 0;
  bool located = false;
  // Preorder walk visits parentForOp itself as well as all nested ops.
  parentForOp->walk<WalkOrder::PreOrder>([&](Operation *subOp) {
    if (subOp == ctrlOp) {
      argIdx = regionsSeen;
      located = true;
    }
    if (regionsWithChannels.contains(subOp)) {
      LDBG("getAccumArgIdx: saw ctrlOp enclosing channel " << subOp);
      ++regionsSeen;
    }
  });
  assert(located && "error in getAccumArgIdx");
  LDBG("getAccumArgIdx: " << parentForOp.getOperation() << " " << ctrlOp << " "
                          << argIdx);
  return argIdx;
}
87+
88+
// Compute and return the buffer index and phase for a given accumulate
// count:
//   bufferIdx = accumCnt - (accumCnt / numBuffers) * numBuffers
//   phase     = (accumCnt / numBuffers) & 1
std::pair<Value, Value> getBufferIdxAndPhase(OpBuilderWithAsyncTaskIds &builder,
                                             Location loc, Value accumCnt,
                                             unsigned numBuffers) {
  // Materialize numBuffers as an i64 so the arithmetic matches accumCnt.
  Value numBuffersVal =
      builder.createWithAsyncTaskIds<arith::ConstantIntOp>(loc, numBuffers, 32);
  numBuffersVal = builder.createWithAsyncTaskIds<arith::ExtSIOp>(
      loc, builder.getI64Type(), numBuffersVal);

  // quotient = accumCnt / numBuffers (unsigned division).
  Value quotient = builder.createWithAsyncTaskIds<arith::DivUIOp>(
      loc, accumCnt, numBuffersVal);

  // bufferIdx = accumCnt - quotient * numBuffers, narrowed to i32.
  Value rounded = builder.createWithAsyncTaskIds<arith::MulIOp>(loc, quotient,
                                                                numBuffersVal);
  Value bufferIdx =
      builder.createWithAsyncTaskIds<arith::SubIOp>(loc, accumCnt, rounded);
  bufferIdx = builder.createWithAsyncTaskIds<arith::TruncIOp>(
      loc, builder.getI32Type(), bufferIdx);

  // phase = low bit of the quotient, narrowed to i1.
  Value one = builder.createWithAsyncTaskIds<arith::ConstantIntOp>(loc, 1, 64);
  Value phaseBit =
      builder.createWithAsyncTaskIds<arith::AndIOp>(loc, quotient, one);
  Value phase = builder.createWithAsyncTaskIds<arith::TruncIOp>(
      loc, builder.getI1Type(), phaseBit);

  return {bufferIdx, phase};
}
115+
116+
// Get the current accumulation count for the given op within its immediate
// scope.
// ForA (accumForA, accumIfA, accumForB, accumIfB)
//   IfA (accumIfA, accumForB)
//     Channel A --> uses ForA.arg[accumIfA]
//     ForB (accumForB)
//       Channel B --> uses ForB.arg[accumForB]
//     ThenYield ForA.arg[accumIfA] + 1, ForB.res[accumForB]
//     ElseYield ForA.arg[accumIfA], ForA.arg[accumForB]
//   ForC (accumForC, accumIfB)
//     IfB
//       Channel C --> uses ForC.arg[accumIfB]
//       ThenYield ForC.arg[accumIfB] + 1
//       ElseYield ForC.arg[accumIfB]
//   Channel D --> uses ForA.arg[accumForA]
Value getAccumCount(OpBuilderWithAsyncTaskIds &builder, Operation *op,
                    const DenseSet<Operation *> &regionsWithChannels) {
  auto parentForOp = op->getParentOfType<scf::ForOp>();
  // Guard the lookup: the code below unconditionally dereferences
  // parentForOp, which would crash if op has no enclosing scf.for.
  assert(parentForOp && "getAccumCount: op has no enclosing scf.for");
  auto *pOp = op->getParentOp();
  // The accum counters appear to be the trailing block arguments of the
  // parent loop body; accumArgId selects the one belonging to pOp.
  unsigned tSize = parentForOp.getBody()->getArguments().size();
  unsigned parentTCnts = getAccumCnts(parentForOp, regionsWithChannels);
  unsigned accumArgId = getAccumArgIdx(parentForOp, pOp, regionsWithChannels);
  Value accumCnt =
      parentForOp.getBody()->getArgument(tSize - parentTCnts + accumArgId);

  LDBG("getAccumCount: parentForOp " << parentForOp.getOperation() << " pOp "
                                     << pOp << " " << tSize << " "
                                     << parentTCnts << " " << accumArgId);
  return accumCnt;
}
147+
148+
// Convenience overload: resolve the accumulate count for `op`, then derive
// the buffer index and phase into the out-params.
void getBufferIdxAndPhase(OpBuilderWithAsyncTaskIds &builder, Operation *op,
                          unsigned numBuffers,
                          const DenseSet<Operation *> &regionsWithChannels,
                          Value &bufferIdx, Value &phase) {
  Value accumCnt = getAccumCount(builder, op, regionsWithChannels);
  auto idxAndPhase =
      getBufferIdxAndPhase(builder, op->getLoc(), accumCnt, numBuffers);
  bufferIdx = idxAndPhase.first;
  phase = idxAndPhase.second;
}
156+
157+
// Return a size-1 subview of `barrierAlloc` that selects the barrier slot
// for pipeline stage `bufferIdx`.
Value getBarrierForPipelineStage(OpBuilderWithAsyncTaskIds &builder,
                                 Value barrierAlloc, Value bufferIdx) {
  auto *context = barrierAlloc.getContext();
  auto allocTy = cast<ttg::MemDescType>(barrierAlloc.getType());
  Attribute sharedMemorySpace =
      triton::gpu::SharedMemorySpaceAttr::get(context);
  // Single i64 barrier slot, same encoding as the full allocation.
  ttg::MemDescType barrierTy =
      ttg::MemDescType::get({1}, builder.getI64Type(), allocTy.getEncoding(),
                            sharedMemorySpace,
                            /*mutableMemory=*/true);

  // Create barrierForTMA from barrierAlloc.
  return builder.createWithAsyncTaskIds<ttg::MemDescSubviewOp>(
      barrierAlloc.getLoc(), barrierTy, barrierAlloc,
      ArrayRef<Value>({bufferIdx}));
}
173+
174+
} // namespace mlir
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
#ifndef NV_DIALECT_HOPPER_TRANSFORMS_CODEPARTITIONUTILITY_H_
2+
#define NV_DIALECT_HOPPER_TRANSFORMS_CODEPARTITIONUTILITY_H_
3+
4+
#include "triton/Dialect/Triton/IR/Dialect.h"
5+
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
6+
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
7+
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
8+
9+
#include "Utility.h"
10+
#include <algorithm>
11+
#include <numeric>
12+
13+
namespace mlir {
14+
15+
namespace tt = mlir::triton;
16+
17+
enum class DataChannelKind { SMEM, TMEM };
18+
19+
struct Channel {
20+
public:
21+
using Relation = std::pair<int, SmallVector<int>>;
22+
23+
Channel(int producer, SmallVector<int> &consumers, Operation *op,
24+
unsigned operandIdx, unsigned numBuffers)
25+
: relation(producer, consumers), op(op), operandIdx(operandIdx),
26+
numBuffers(numBuffers) {}
27+
28+
bool operator==(const Channel &c) {
29+
return relation == c.relation && operandIdx == c.operandIdx && op == c.op;
30+
}
31+
32+
Operation *getDstOp() { return op; }
33+
unsigned getDstOperandIdx() { return operandIdx; }
34+
virtual Value getSrcOperand() { return op->getOperand(operandIdx); }
35+
virtual Operation *getSrcOp() { return getSrcOperand().getDefiningOp(); }
36+
37+
Relation relation; // producer task Id, a list of consumer task Ids
38+
Operation *op;
39+
unsigned operandIdx;
40+
unsigned numBuffers;
41+
DataChannelKind channelKind = DataChannelKind::SMEM;
42+
};
43+
44+
// Synchronization objects materialized for a producer/consumer pairing.
struct CommChannel {
  // Tokens keyed by int id — presumably the consumer task id; confirm
  // against the code that populates this map.
  DenseMap<int, Value> tokens;
  // Producer barrier is only needed when the producer op itself can update
  // the barrier inline, such as the TMA load.
  std::optional<Value> producerBarrier;
  // Consumer barrier is only needed when the consumer op itself can update
  // the barrier inline, such as the TCGen5MMAOp.
  DenseMap<int, Value> consumerBarriers;
};
53+
54+
namespace ttng = ::mlir::triton::nvidia_gpu;
55+
namespace triton {
56+
namespace nvidia_gpu {
57+
struct TmemDataChannel : Channel {
58+
ttng::TMEMAllocOp tmemAllocOp;
59+
ttng::TCGen5MMAOp tmemMmaOp;
60+
Operation *tmemProducerOp;
61+
62+
TmemDataChannel(int producer, SmallVector<int> &consumers,
63+
ttng::TMEMAllocOp tmemAllocOp, ttng::TCGen5MMAOp tmemMmaOp,
64+
Operation *tmemLoadOp, unsigned operandIdx,
65+
unsigned numBuffers)
66+
: Channel(producer, consumers, tmemLoadOp, operandIdx, numBuffers),
67+
tmemAllocOp(tmemAllocOp), tmemProducerOp(tmemAllocOp),
68+
tmemMmaOp(tmemMmaOp) {
69+
assert(consumers.size() == 1 &&
70+
"TmemDataChannel must have a single consumer");
71+
channelKind = DataChannelKind::TMEM;
72+
}
73+
74+
ttng::TMEMAllocOp getAllocOp() { return tmemAllocOp; }
75+
ttng::TCGen5MMAOp getMmaOp() { return tmemMmaOp; }
76+
virtual Operation *getSrcOp() { return tmemProducerOp; }
77+
};
78+
} // namespace nvidia_gpu
79+
} // namespace triton
80+
81+
bool enclosing(scf::IfOp ifOp, Operation *op);
82+
bool enclosing(scf::ForOp forOp, Operation *op);
83+
84+
// Return number of AccumCnts for the given ctrlOp. Add a single
85+
// AccumCnt for all channels under opsWithBufferReuse and it will be the
86+
// last AccumCnt.
87+
unsigned getAccumCnts(Operation *ctrlOp,
88+
const DenseSet<Operation *> &regionsWithChannels);
89+
90+
unsigned getAccumArgIdx(scf::ForOp parentForOp, Operation *ctrlOp,
91+
const DenseSet<Operation *> &regionsWithChannels);
92+
93+
SmallVector<Operation *>
94+
getTaskTopRegion(triton::FuncOp funcOp, const SmallVector<Channel *> &channels);
95+
96+
void appendAccumCntsForOps(SmallVector<Operation *> &taskTopOps,
97+
const SmallVector<Channel *> &channels,
98+
DenseSet<Operation *> &regionsWithChannels);
99+
100+
void collectRegionsWithChannels(const SmallVector<Channel *> &channels,
101+
DenseSet<Operation *> &regionsWithChannels);
102+
void insertAsyncCopy(
103+
triton::FuncOp funcOp,
104+
const DenseMap<Channel *, SmallVector<Channel *>>
105+
&channelsGroupedByProducers,
106+
const DenseMap<Channel *, Value> &bufferMap,
107+
DenseMap<Channel *, std::pair<Operation *, Operation *>> &copyOpMap,
108+
DenseSet<Operation *> &regionsWithChannels);
109+
110+
Value getAccumCount(OpBuilderWithAsyncTaskIds &builder, Operation *op,
111+
const DenseSet<Operation *> &regionsWithChannels);
112+
std::pair<Value, Value> getBufferIdxAndPhase(OpBuilderWithAsyncTaskIds &builder,
113+
Location loc, Value accumCnt,
114+
unsigned numBuffers);
115+
void getBufferIdxAndPhase(OpBuilderWithAsyncTaskIds &builder, Operation *op,
116+
unsigned numBuffers,
117+
const DenseSet<Operation *> &regionsWithChannels,
118+
Value &bufferIdx, Value &phase);
119+
120+
Value getBarrierForPipelineStage(OpBuilderWithAsyncTaskIds &builder,
121+
Value barrierAlloc, Value bufferIdx);
122+
123+
Operation *optimizeTMALoads(OpBuilderWithAsyncTaskIds &builder,
124+
SmallVector<tt::DescriptorLoadOp> &tmaLoads,
125+
SmallVector<Value> &buffers, Value barrierAlloc,
126+
Value bufferIdx, Value bufferIdxExtract,
127+
Value phase, Operation *headProducer,
128+
Operation *headConsumer);
129+
void specializeRegion(triton::FuncOp funcOp);
130+
131+
} // namespace mlir
132+
133+
#endif // NV_DIALECT_HOPPER_TRANSFORMS_CODEPARTITIONUTILITY_H_

third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/Utility.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
#ifndef NV_DIALECT_HOPPER_TRANSFORMS_UTILITY_H_
32

43
#include "mlir/IR/Builders.h"

0 commit comments

Comments
 (0)