@@ -67,26 +67,16 @@ class AssignLoadLatencies {
6767 ModuleOp moduleOp = forOp->getParentOfType <ModuleOp>();
6868 tt::ModuleAxisInfoAnalysis axisInfoAnalysis (moduleOp);
6969
70- llvm::MapVector<Operation *, int > loadOpToIndLevel =
71- loadOpsToIndirectionLevel (forOp, pipelineWithoutDot, axisInfoAnalysis);
70+ llvm::MapVector<Operation *, std::pair<int , Operation *>> loadOpToIndLevel =
71+ loadOpsToIndirectionLevel (forOp, pipelineWithoutDot, axisInfoAnalysis,
72+ numStages);
7273 if (loadOpToIndLevel.empty ())
7374 return ;
7475
75- // We assume loads with different dist are assigned to different stages.
76- // If numStages is 2, we will have no stage available for indirect loads
77- // with dist >= 1. In general, when dist is equal to numStages - 1, we
78- // should not pipeline it.
79- for (auto iter = loadOpToIndLevel.begin ();
80- iter != loadOpToIndLevel.end ();) {
81- if (iter->second >= numStages - 1 )
82- iter = loadOpToIndLevel.erase (iter);
83- else
84- ++iter;
85- }
86-
8776 // Calculate the stage distance between applicable loads.
88- auto vals = llvm::make_second_range (loadOpToIndLevel);
89- int maxIndirectionLevel = vals.empty () ? 0 : *llvm::max_element (vals);
77+ int maxIndirectionLevel = 0 ;
78+ for (auto &[loadOp, info] : loadOpToIndLevel)
79+ maxIndirectionLevel = std::max (maxIndirectionLevel, info.first );
9080 unsigned loadLatency = (numStages - 1 ) / (maxIndirectionLevel + 1 );
9181
9282 for (auto [loadOp, dist] : loadOpToIndLevel) {
@@ -99,17 +89,20 @@ class AssignLoadLatencies {
9989 int numStages;
10090 DenseMap<Operation *, int > &opLatency;
10191
102- bool canHaveSharedEncoding (tt::LoadOp op) {
92+ public:
93+ static bool canHaveSharedEncoding (tt::LoadOp op) {
10394 // If used by an user with DotOp encoding, all the uses must be compatible.
10495 bool incompatible = false ;
10596 getSharedEncIfAllUsersAreDotEnc (op.getResult (), incompatible);
10697 return !incompatible;
10798 }
10899
109- bool isPipeliningBeneficial (Operation *op, Operation *finalUser,
110- tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) {
100+ static bool
101+ isPipeliningBeneficial (Operation *op, Operation *finalUser,
102+ tt::ModuleAxisInfoAnalysis &axisInfoAnalysis,
103+ bool filterSmall) {
111104 if (auto loadOp = dyn_cast<tt::LoadOp>(op)) {
112- if (!canBeConvertedToAsyncLoad (loadOp, axisInfoAnalysis)) {
105+ if (filterSmall && !canBeConvertedToAsyncLoad (loadOp, axisInfoAnalysis)) {
113106 LDBG (" Load " << *loadOp << " is too small for pipelining" );
114107 return false ;
115108 }
@@ -145,90 +138,14 @@ class AssignLoadLatencies {
145138 if (localAllocEnc) {
146139 auto registerTy = cast<RankedTensorType>(op->getResultTypes ()[0 ]);
147140 auto vecBytes = getCopyVecBytes (registerTy, localAllocEnc);
148- if (vecBytes < 4 ) {
141+ if (filterSmall && vecBytes < 4 ) {
149142 // At least 4 bytes need to be consecutive for cp.async
150143 return false ;
151144 }
152145 }
153146
154147 return true ;
155148 }
156-
157- // Create a map from load ops to their indirection level and the
158- // final use of the load op (another load op, or a dot op).
159- // Indirection level is "0" for the load op directly used by the dot op,
160- // "1" for the load op used by the load op used by the dot op, and so on.
161- llvm::MapVector<Operation *, int >
162- loadOpsToIndirectionLevel (scf::ForOp forOp, bool pipelineWithoutDot,
163- tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) {
164- llvm::MapVector<Operation *, int > loadOpToIndLevel;
165- DenseSet<Operation *> seen;
166- DenseSet<Operation *> excluded;
167-
168- std::function<void (Operation *, Operation *, int )> dfs =
169- [&](Operation *op, Operation *finalUser, int distance) {
170- if (!seen.insert (op).second || excluded.count (op))
171- return ;
172- if (isa<tt::LoadOp, tt::DescriptorLoadOp, tt::DescriptorGatherOp>(
173- op)) {
174- if (!isPipeliningBeneficial (op, finalUser, axisInfoAnalysis))
175- return ;
176- if (loadOpToIndLevel.count (op)) {
177- int level = loadOpToIndLevel[op];
178- if (level != distance) {
179- // If we have multiple uses at different distances, we don't
180- // know which one to pick.
181- LDBG (" Load " << *op
182- << " has multiple uses at different distances:"
183- << level << " and " << distance);
184- loadOpToIndLevel.erase (op);
185- excluded.insert (op);
186- return ;
187- }
188- } else {
189- LDBG (" Load " << *op << " considered for pipelining with distance "
190- << distance);
191- loadOpToIndLevel[op] = distance;
192- }
193- finalUser = op;
194- distance++;
195- }
196- for (Value operand : getNestedOperands (op)) {
197- if (isa<mlir::triton::DotOpInterface>(op)) {
198- // Heuristic: only pipeline A and B operands of the dot op.
199- if (operand == op->getOperand (2 ))
200- continue ;
201- }
202- Value v = operand;
203- Operation *defOp = v.getDefiningOp ();
204- if (defOp && defOp->getBlock () == op->getBlock ()) {
205- dfs (defOp, finalUser, distance);
206- }
207- }
208- };
209-
210- bool seenDot = false ;
211- for (Operation &op : forOp.getBody ()->without_terminator ()) {
212- // Arbitrary heuristic. TMEMStoreOp is included to keep logic consistent
213- // with legacy code when we weren't hoisting tmem allocas.
214- if (!isa<mlir::triton::DotOpInterface, ttng::TMEMStoreOp>(op))
215- continue ;
216- seenDot = true ;
217- seen.clear ();
218- dfs (&op, &op, 0 );
219- }
220-
221- // If the loop has numStages attribute, also consider pipelining other loads
222- // that are not directly used by dot ops.
223- if (pipelineWithoutDot && !seenDot) {
224- for (Operation &op : forOp.getBody ()->without_terminator ()) {
225- if (!isa<tt::LoadOp, tt::DescriptorLoadOp, tt::DescriptorGatherOp>(op))
226- dfs (&op, &op, 0 );
227- }
228- }
229-
230- return loadOpToIndLevel;
231- }
232149};
233150
234151class AssignMMALatencies {
@@ -335,6 +252,94 @@ void assignLatencies(ModuleOp moduleOp, int defaultNumStages) {
335252
336253} // namespace
337254
255+ // Create a map from load ops to their indirection level and the
256+ // final use of the load op (another load op, or a dot op).
257+ // Indirection level is "0" for the load op directly used by the dot op,
258+ // "1" for the load op used by the load op used by the dot op, and so on.
259+ llvm::MapVector<Operation *, std::pair<int , Operation *>>
260+ loadOpsToIndirectionLevel (scf::ForOp forOp, bool pipelineWithoutDot,
261+ tt::ModuleAxisInfoAnalysis &axisInfoAnalysis,
262+ int numStages, bool filterSmall) {
263+ llvm::MapVector<Operation *, std::pair<int , Operation *>> loadOpToIndLevel;
264+ DenseSet<Operation *> seen;
265+ DenseSet<Operation *> excluded;
266+
267+ std::function<void (Operation *, Operation *, int )> dfs =
268+ [&](Operation *op, Operation *finalUser, int distance) {
269+ if (!seen.insert (op).second || excluded.count (op))
270+ return ;
271+ if (isa<tt::LoadOp, tt::DescriptorLoadOp, tt::DescriptorGatherOp>(op)) {
272+ if (!AssignLoadLatencies::isPipeliningBeneficial (
273+ op, finalUser, axisInfoAnalysis, filterSmall))
274+ return ;
275+ if (loadOpToIndLevel.count (op)) {
276+ int level = loadOpToIndLevel[op].first ;
277+ if (level != distance) {
278+ // If we have multiple uses at different distances, we don't
279+ // know which one to pick.
280+ LDBG (" Load " << *op
281+ << " has multiple uses at different distances:"
282+ << level << " and " << distance);
283+ loadOpToIndLevel.erase (op);
284+ excluded.insert (op);
285+ return ;
286+ }
287+ } else {
288+ LDBG (" Load " << *op << " considered for pipelining with distance "
289+ << distance);
290+ loadOpToIndLevel[op] = {distance, finalUser};
291+ }
292+ finalUser = op;
293+ distance++;
294+ }
295+ for (Value operand : getNestedOperands (op)) {
296+ if (isa<mlir::triton::DotOpInterface>(op)) {
297+ // Heuristic: only pipeline A and B operands of the dot op.
298+ if (operand == op->getOperand (2 ))
299+ continue ;
300+ }
301+ Value v = operand;
302+ Operation *defOp = v.getDefiningOp ();
303+ if (defOp && defOp->getBlock () == op->getBlock ()) {
304+ dfs (defOp, finalUser, distance);
305+ }
306+ }
307+ };
308+
309+ bool seenDot = false ;
310+ for (Operation &op : forOp.getBody ()->without_terminator ()) {
311+ // Arbitrary heuristic. TMEMStoreOp is included to keep logic consistent
312+ // with legacy code when we weren't hoisting tmem allocas.
313+ if (!isa<mlir::triton::DotOpInterface, ttng::TMEMStoreOp>(op))
314+ continue ;
315+ seenDot = true ;
316+ seen.clear ();
317+ dfs (&op, &op, 0 );
318+ }
319+
320+ // If the loop has numStages attribute, also consider pipelining other loads
321+ // that are not directly used by dot ops.
322+ if (pipelineWithoutDot && !seenDot) {
323+ for (Operation &op : forOp.getBody ()->without_terminator ()) {
324+ if (!isa<tt::LoadOp, tt::DescriptorLoadOp, tt::DescriptorGatherOp>(op))
325+ dfs (&op, &op, 0 );
326+ }
327+ }
328+
329+ // We assume loads with different dist are assigned to different stages.
330+ // If numStages is 2, we will have no stage available for indirect loads
331+ // with dist >= 1. In general, when dist is equal to numStages - 1, we
332+ // should not pipeline it.
333+ for (auto iter = loadOpToIndLevel.begin (); iter != loadOpToIndLevel.end ();) {
334+ if (iter->second .first >= numStages - 1 )
335+ iter = loadOpToIndLevel.erase (iter);
336+ else
337+ ++iter;
338+ }
339+
340+ return loadOpToIndLevel;
341+ }
342+
338343// ===----------------------------------------------------------------------===//
339344// Pass Definition
340345// ===----------------------------------------------------------------------===//
0 commit comments