@@ -67,26 +67,16 @@ class AssignLoadLatencies {
     ModuleOp moduleOp = forOp->getParentOfType<ModuleOp>();
     tt::ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp);
 
-    llvm::MapVector<Operation *, int> loadOpToIndLevel =
-        loadOpsToIndirectionLevel(forOp, pipelineWithoutDot, axisInfoAnalysis);
+    llvm::MapVector<Operation *, std::pair<int, Operation *>> loadOpToIndLevel =
+        loadOpsToIndirectionLevel(forOp, pipelineWithoutDot, axisInfoAnalysis,
+                                  numStages);
     if (loadOpToIndLevel.empty())
       return;
 
-    // We assume loads with different dist are assigned to different stages.
-    // If numStages is 2, we will have no stage available for indirect loads
-    // with dist >= 1. In general, when dist is equal to numStages - 1, we
-    // should not pipeline it.
-    for (auto iter = loadOpToIndLevel.begin();
-         iter != loadOpToIndLevel.end();) {
-      if (iter->second >= numStages - 1)
-        iter = loadOpToIndLevel.erase(iter);
-      else
-        ++iter;
-    }
-
     // Calculate the stage distance between applicable loads.
-    auto vals = llvm::make_second_range(loadOpToIndLevel);
-    int maxIndirectionLevel = vals.empty() ? 0 : *llvm::max_element(vals);
+    int maxIndirectionLevel = 0;
+    for (auto &[loadOp, info] : loadOpToIndLevel)
+      maxIndirectionLevel = std::max(maxIndirectionLevel, info.first);
     unsigned loadLatency = (numStages - 1) / (maxIndirectionLevel + 1);
 
     for (auto [loadOp, dist] : loadOpToIndLevel) {
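
For context (not part of the patch): the `loadLatency` line divides the stages available ahead of the dot evenly across the chain of dependent loads. A minimal standalone sketch of that arithmetic, with illustrative values that do not come from this commit:

#include <cstdio>

int main() {
  // Assumed example: 4 pipeline stages and one level of indirection
  // (a load producing pointers for another load that feeds the dot).
  int numStages = 4;
  int maxIndirectionLevel = 1;
  // Each load in the chain gets an equal share of the stages before the dot.
  unsigned loadLatency = (numStages - 1) / (maxIndirectionLevel + 1);
  std::printf("loadLatency = %u\n", loadLatency); // prints 1
  return 0;
}
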
@@ -99,17 +89,20 @@ class AssignLoadLatencies {
   int numStages;
   DenseMap<Operation *, int> &opLatency;
 
-  bool canHaveSharedEncoding(tt::LoadOp op) {
+public:
+  static bool canHaveSharedEncoding(tt::LoadOp op) {
     // If used by an user with DotOp encoding, all the uses must be compatible.
     bool incompatible = false;
     getSharedEncIfAllUsersAreDotEnc(op.getResult(), incompatible);
     return !incompatible;
   }
 
-  bool isPipeliningBeneficial(Operation *op, Operation *finalUser,
-                              tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) {
+  static bool
+  isPipeliningBeneficial(Operation *op, Operation *finalUser,
+                         tt::ModuleAxisInfoAnalysis &axisInfoAnalysis,
+                         bool filterSmall) {
     if (auto loadOp = dyn_cast<tt::LoadOp>(op)) {
-      if (!canBeConvertedToAsyncLoad(loadOp, axisInfoAnalysis)) {
+      if (filterSmall && !canBeConvertedToAsyncLoad(loadOp, axisInfoAnalysis)) {
         LDBG("Load " << *loadOp << " is too small for pipelining");
         return false;
       }
@@ -145,90 +138,14 @@ class AssignLoadLatencies {
     if (localAllocEnc) {
       auto registerTy = cast<RankedTensorType>(op->getResultTypes()[0]);
       auto vecBytes = getCopyVecBytes(registerTy, localAllocEnc);
-      if (vecBytes < 4) {
+      if (filterSmall && vecBytes < 4) {
         // At least 4 bytes need to be consecutive for cp.async
        return false;
       }
     }
 
     return true;
   }
-
-  // Create a map from load ops to their indirection level and the
-  // final use of the load op (another load op, or a dot op).
-  // Indirection level is "0" for the load op directly used by the dot op,
-  // "1" for the load op used by the load op used by the dot op, and so on.
-  llvm::MapVector<Operation *, int>
-  loadOpsToIndirectionLevel(scf::ForOp forOp, bool pipelineWithoutDot,
-                            tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) {
-    llvm::MapVector<Operation *, int> loadOpToIndLevel;
-    DenseSet<Operation *> seen;
-    DenseSet<Operation *> excluded;
-
-    std::function<void(Operation *, Operation *, int)> dfs =
-        [&](Operation *op, Operation *finalUser, int distance) {
-          if (!seen.insert(op).second || excluded.count(op))
-            return;
-          if (isa<tt::LoadOp, tt::DescriptorLoadOp, tt::DescriptorGatherOp>(
-                  op)) {
-            if (!isPipeliningBeneficial(op, finalUser, axisInfoAnalysis))
-              return;
-            if (loadOpToIndLevel.count(op)) {
-              int level = loadOpToIndLevel[op];
-              if (level != distance) {
-                // If we have multiple uses at different distances, we don't
-                // know which one to pick.
-                LDBG("Load " << *op
-                             << " has multiple uses at different distances:"
-                             << level << " and " << distance);
-                loadOpToIndLevel.erase(op);
-                excluded.insert(op);
-                return;
-              }
-            } else {
-              LDBG("Load " << *op << " considered for pipelining with distance "
-                           << distance);
-              loadOpToIndLevel[op] = distance;
-            }
-            finalUser = op;
-            distance++;
-          }
-          for (Value operand : getNestedOperands(op)) {
-            if (isa<mlir::triton::DotOpInterface>(op)) {
-              // Heuristic: only pipeline A and B operands of the dot op.
-              if (operand == op->getOperand(2))
-                continue;
-            }
-            Value v = operand;
-            Operation *defOp = v.getDefiningOp();
-            if (defOp && defOp->getBlock() == op->getBlock()) {
-              dfs(defOp, finalUser, distance);
-            }
-          }
-        };
-
-    bool seenDot = false;
-    for (Operation &op : forOp.getBody()->without_terminator()) {
-      // Arbitrary heuristic. TMEMStoreOp is included to keep logic consistent
-      // with legacy code when we weren't hoisting tmem allocas.
-      if (!isa<mlir::triton::DotOpInterface, ttng::TMEMStoreOp>(op))
-        continue;
-      seenDot = true;
-      seen.clear();
-      dfs(&op, &op, 0);
-    }
-
-    // If the loop has numStages attribute, also consider pipelining other loads
-    // that are not directly used by dot ops.
-    if (pipelineWithoutDot && !seenDot) {
-      for (Operation &op : forOp.getBody()->without_terminator()) {
-        if (!isa<tt::LoadOp, tt::DescriptorLoadOp, tt::DescriptorGatherOp>(op))
-          dfs(&op, &op, 0);
-      }
-    }
-
-    return loadOpToIndLevel;
-  }
 };
 
 class AssignMMALatencies {
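
Aside (illustrative only, not from the patch): the `vecBytes < 4` check above reflects cp.async's requirement that each thread copy at least 4 contiguous bytes. A hedged sketch of the byte math, using a hypothetical stand-in for `getCopyVecBytes`:

#include <cstdio>

// Hypothetical stand-in for getCopyVecBytes: bytes moved per cp.async
// instruction given the contiguous element count and element bit width.
static unsigned copyVecBytes(unsigned contigElems, unsigned elemBitWidth) {
  return contigElems * elemBitWidth / 8;
}

int main() {
  bool filterSmall = true;
  // e.g. 2 contiguous fp16 elements -> 4 bytes: just wide enough for cp.async.
  unsigned vecBytes = copyVecBytes(/*contigElems=*/2, /*elemBitWidth=*/16);
  if (filterSmall && vecBytes < 4)
    std::printf("too small for cp.async, not pipelined\n");
  else
    std::printf("eligible for async pipelining (%u bytes)\n", vecBytes);
  return 0;
}
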
@@ -335,6 +252,94 @@ void assignLatencies(ModuleOp moduleOp, int defaultNumStages) {
 
 } // namespace
 
+// Create a map from load ops to their indirection level and the
+// final use of the load op (another load op, or a dot op).
+// Indirection level is "0" for the load op directly used by the dot op,
+// "1" for the load op used by the load op used by the dot op, and so on.
+llvm::MapVector<Operation *, std::pair<int, Operation *>>
+loadOpsToIndirectionLevel(scf::ForOp forOp, bool pipelineWithoutDot,
+                          tt::ModuleAxisInfoAnalysis &axisInfoAnalysis,
+                          int numStages, bool filterSmall) {
+  llvm::MapVector<Operation *, std::pair<int, Operation *>> loadOpToIndLevel;
+  DenseSet<Operation *> seen;
+  DenseSet<Operation *> excluded;
+
+  std::function<void(Operation *, Operation *, int)> dfs =
+      [&](Operation *op, Operation *finalUser, int distance) {
+        if (!seen.insert(op).second || excluded.count(op))
+          return;
+        if (isa<tt::LoadOp, tt::DescriptorLoadOp, tt::DescriptorGatherOp>(op)) {
+          if (!AssignLoadLatencies::isPipeliningBeneficial(
+                  op, finalUser, axisInfoAnalysis, filterSmall))
+            return;
+          if (loadOpToIndLevel.count(op)) {
+            int level = loadOpToIndLevel[op].first;
+            if (level != distance) {
+              // If we have multiple uses at different distances, we don't
+              // know which one to pick.
+              LDBG("Load " << *op
+                           << " has multiple uses at different distances:"
+                           << level << " and " << distance);
+              loadOpToIndLevel.erase(op);
+              excluded.insert(op);
+              return;
+            }
+          } else {
+            LDBG("Load " << *op << " considered for pipelining with distance "
+                         << distance);
+            loadOpToIndLevel[op] = {distance, finalUser};
+          }
+          finalUser = op;
+          distance++;
+        }
+        for (Value operand : getNestedOperands(op)) {
+          if (isa<mlir::triton::DotOpInterface>(op)) {
+            // Heuristic: only pipeline A and B operands of the dot op.
+            if (operand == op->getOperand(2))
+              continue;
+          }
+          Value v = operand;
+          Operation *defOp = v.getDefiningOp();
+          if (defOp && defOp->getBlock() == op->getBlock()) {
+            dfs(defOp, finalUser, distance);
+          }
+        }
+      };
+
+  bool seenDot = false;
+  for (Operation &op : forOp.getBody()->without_terminator()) {
+    // Arbitrary heuristic. TMEMStoreOp is included to keep logic consistent
+    // with legacy code when we weren't hoisting tmem allocas.
+    if (!isa<mlir::triton::DotOpInterface, ttng::TMEMStoreOp>(op))
+      continue;
+    seenDot = true;
+    seen.clear();
+    dfs(&op, &op, 0);
+  }
+
+  // If the loop has numStages attribute, also consider pipelining other loads
+  // that are not directly used by dot ops.
+  if (pipelineWithoutDot && !seenDot) {
+    for (Operation &op : forOp.getBody()->without_terminator()) {
+      if (!isa<tt::LoadOp, tt::DescriptorLoadOp, tt::DescriptorGatherOp>(op))
+        dfs(&op, &op, 0);
+    }
+  }
+
+  // We assume loads with different dist are assigned to different stages.
+  // If numStages is 2, we will have no stage available for indirect loads
+  // with dist >= 1. In general, when dist is equal to numStages - 1, we
+  // should not pipeline it.
+  for (auto iter = loadOpToIndLevel.begin(); iter != loadOpToIndLevel.end();) {
+    if (iter->second.first >= numStages - 1)
+      iter = loadOpToIndLevel.erase(iter);
+    else
+      ++iter;
+  }
+
+  return loadOpToIndLevel;
+}
+
 //===----------------------------------------------------------------------===//
 // Pass Definition
 //===----------------------------------------------------------------------===//
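
For readers skimming the new free function: with `numStages` stages available, a load at indirection level `dist` needs `dist + 1` stages between its issue and the dot, so anything with `dist >= numStages - 1` is dropped from pipelining. A self-contained sketch of that filtering step, using a plain std::vector in place of llvm::MapVector and made-up load names:

#include <cstdio>
#include <string>
#include <utility>
#include <vector>

int main() {
  int numStages = 2;
  // (indirection level, name) pairs standing in for the MapVector entries.
  std::vector<std::pair<int, std::string>> loads = {
      {0, "load_a"}, // feeds the dot directly
      {1, "load_b"}, // produces the pointers consumed by load_a
  };
  // Drop loads whose level leaves no stage in which to hide their latency.
  for (auto it = loads.begin(); it != loads.end();) {
    if (it->first >= numStages - 1)
      it = loads.erase(it);
    else
      ++it;
  }
  // With numStages = 2, only load_a survives.
  for (auto &[dist, name] : loads)
    std::printf("%s pipelined at distance %d\n", name.c_str(), dist);
  return 0;
}
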