7
7
#include " mlir/IR/Operation.h"
8
8
#include " mlir/IR/Value.h"
9
9
#include " mlir/IR/Verifier.h"
10
+ #include " mlir/Interfaces/LoopLikeInterface.h"
10
11
#include " mlir/Support/LLVM.h"
11
12
#include " triton/Dialect/Triton/IR/Types.h"
12
13
#include " triton/Dialect/Triton/IR/Utility.h"
13
14
#include " triton/Dialect/TritonGPU/Transforms/Utility.h"
14
15
#include " triton/Tools/StrUtil.h"
16
+ #include " llvm/ADT/STLExtras.h"
15
17
#include " llvm/Support/Debug.h"
16
18
#include " llvm/Support/ErrorHandling.h"
17
19
#include " llvm/Support/raw_ostream.h"
@@ -154,55 +156,58 @@ struct CoalescePass
154
156
return false ;
155
157
}
156
158
157
- // Change the \p layout of the \p op result and propagate the new result type
158
- // to its users.
159
- void changeAndPropagateLayout (Operation *op, Attribute layout,
159
+ // Change the \p layout of the \p op's result \p opRes and propagate the new
160
+ // result type to its users.
161
+ void changeAndPropagateLayout (Operation *op, Value opRes, Attribute layout,
160
162
IRRewriter &rewriter) const {
161
163
assert (op && op->getNumResults () != 0 &&
162
164
" Expecting operation yielding results" );
163
165
164
166
LLVM_DEBUG ({
165
167
llvm::dbgs () << " [" DEBUG_TYPE " ]: " << " ChangeAndPropagateLayout for: " ;
166
168
op->dumpPretty ();
169
+ llvm::dbgs () << " opRes: " ;
170
+ opRes.printAsOperand (llvm::dbgs (), {});
171
+ llvm::dbgs () << " \n " ;
167
172
});
168
173
169
174
rewriter.modifyOpInPlace (op, [&]() {
170
- for (Value res : op->getResults ()) {
171
- if (!tt::isTensorPointerType (res.getType ()))
172
- continue ;
173
-
174
- auto ptrType = cast<tt::PointerType>(res.getType ());
175
- auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType ());
176
- res.setType (tt::PointerType::get (getNewType (tensorType, layout),
175
+ assert (tt::isTensorPointerType (opRes.getType ()));
176
+ auto ptrType = cast<tt::PointerType>(opRes.getType ());
177
+ auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType ());
178
+ opRes.setType (tt::PointerType::get (getNewType (tensorType, layout),
177
179
ptrType.getAddressSpace ()));
178
- }
179
180
});
180
181
181
182
LLVM_DEBUG ({
182
183
llvm::dbgs () << " [" DEBUG_TYPE " ]: Coalesced op: " ;
183
184
op->dumpPretty ();
184
185
});
185
186
186
- propagateLayout (op, layout, rewriter);
187
+ for (OpResult res : op->getResults ())
188
+ if (res == opRes)
189
+ propagateLayout (op, res, layout, rewriter);
187
190
}
188
191
189
192
// Propagate the layout of the \p op operation's result \p opRes to its users.
190
- void propagateLayout (Operation *root , Attribute layout,
193
+ void propagateLayout (Operation *op, Value opRes , Attribute layout,
191
194
IRRewriter &rewriter) const {
192
- assert (root ->getNumResults () != 0 &&
195
+ assert (op && op ->getNumResults () != 0 &&
193
196
" Expecting an operation yielding a result" );
194
-
195
- auto mod = root->getParentOfType <ModuleOp>();
197
+ assert (opRes &&
198
+ llvm::any_of (op->getResults (),
199
+ [&](OpResult res) { return res == opRes; }) &&
200
+ " Expecting operation to yield 'opRes'" );
196
201
197
202
LLVM_DEBUG ({
198
- if (!root-> getUsers ().empty ()) {
203
+ if (!opRes. getUsers ().empty ()) {
199
204
llvm::dbgs () << " [" DEBUG_TYPE " ]: "
200
- << " Propagate layout to operations using: " ;
201
- root-> dumpPretty () ;
205
+ << " Propagate layout to operations using: " << opRes
206
+ << " \n " ;
202
207
}
203
208
});
204
209
205
- for (Operation *user : root-> getUsers ()) {
210
+ for (Operation *user : opRes. getUsers ()) {
206
211
if (filterUser (user))
207
212
continue ;
208
213
@@ -212,50 +217,71 @@ struct CoalescePass
212
217
});
213
218
214
219
if (auto forOp = dyn_cast<scf::ForOp>(user)) {
215
- propagateLayoutToArgsAndBody (forOp, root , layout, rewriter);
220
+ propagateLayoutToArgsAndBody (forOp, opRes , layout, rewriter);
216
221
continue ;
217
222
}
218
223
if (auto whileOp = dyn_cast<scf::WhileOp>(user)) {
219
- propagateLayoutToArgsAndBody (whileOp, root , layout, rewriter);
224
+ propagateLayoutToArgsAndBody (whileOp, opRes , layout, rewriter);
220
225
continue ;
221
226
}
222
-
223
227
if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
224
- if (auto forOp = yieldOp->getParentOfType <scf::ForOp>())
225
- propagateLayoutToLoopResults (forOp, layout, rewriter);
226
- if (auto whileOp = yieldOp->getParentOfType <scf::WhileOp>())
227
- propagateLayoutToLoopResults (whileOp, layout, rewriter);
228
- continue ;
228
+ if (auto loopOp = yieldOp->getParentOfType <LoopLikeOpInterface>()) {
229
+ for (OpOperand &operand : yieldOp->getOpOperands ())
230
+ if (operand.get () == opRes)
231
+ propagateLayoutToLoopResults (loopOp, operand.getOperandNumber (),
232
+ layout, rewriter);
233
+ continue ;
234
+ }
229
235
}
230
236
231
237
LLVM_DEBUG ({
232
238
llvm::dbgs () << " [" DEBUG_TYPE " ]: After propagating layout:\n " ;
233
- mod ->dumpPretty ();
239
+ op-> getParentOfType <ModuleOp>() ->dumpPretty ();
234
240
});
235
241
236
- changeAndPropagateLayout (user, layout, rewriter);
242
+ for (OpResult res : user->getResults ())
243
+ changeAndPropagateLayout (user, res, layout, rewriter);
237
244
}
238
245
}
239
246
240
247
// Propagate the layout of the \p arg block argument to its users.
241
248
void propagateLayout (BlockArgument arg, Attribute layout,
242
249
IRRewriter &rewriter) const {
250
+ LLVM_DEBUG ({
251
+ if (!arg.getUsers ().empty ()) {
252
+ llvm::dbgs () << " [" DEBUG_TYPE " ]: "
253
+ << " Propagate layout to operations using: " ;
254
+ arg.printAsOperand (llvm::dbgs (), {});
255
+ llvm::dbgs () << " \n " ;
256
+ }
257
+ });
258
+
243
259
for (Operation *user : arg.getUsers ()) {
244
260
if (filterUser (user))
245
261
continue ;
246
262
247
263
LLVM_DEBUG ({
248
- llvm::dbgs () << " [" DEBUG_TYPE " ]: " << " arg's user: " ;
264
+ llvm::dbgs () << " [" DEBUG_TYPE " ]: " << " user: " ;
249
265
user->dumpPretty ();
250
266
});
251
267
252
- if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
253
- if (auto forOp = yieldOp->getParentOfType <scf::ForOp>())
254
- propagateLayoutToLoopResults (forOp, layout, rewriter);
255
- if (auto whileOp = yieldOp->getParentOfType <scf::WhileOp>())
256
- propagateLayoutToLoopResults (whileOp, layout, rewriter);
268
+ if (auto forOp = dyn_cast<scf::ForOp>(user)) {
269
+ propagateLayoutToArgsAndBody (forOp, arg, layout, rewriter);
257
270
continue ;
258
271
}
272
+ if (auto whileOp = dyn_cast<scf::WhileOp>(user)) {
273
+ propagateLayoutToArgsAndBody (whileOp, arg, layout, rewriter);
274
+ continue ;
275
+ }
276
+ if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
277
+ if (auto loopOp = yieldOp->getParentOfType <LoopLikeOpInterface>()) {
278
+ for (OpOperand &operand : yieldOp->getOpOperands ())
279
+ if (operand.get () == arg)
280
+ propagateLayoutToLoopResults (loopOp, operand.getOperandNumber (),
281
+ layout, rewriter);
282
+ continue ;
283
+ }
284
+ }
259
285
if (auto condOp = dyn_cast<scf::ConditionOp>(user)) {
260
286
if (auto whileOp = condOp->getParentOfType <scf::WhileOp>()) {
261
287
// Propagate layout to "after" region arguments.
@@ -284,7 +310,8 @@ struct CoalescePass
284
310
continue ;
285
311
}
286
312
287
- changeAndPropagateLayout (user, layout, rewriter);
313
+ for (OpResult res : user->getResults ())
314
+ changeAndPropagateLayout (user, res, layout, rewriter);
288
315
}
289
316
290
317
LLVM_DEBUG ({
@@ -300,74 +327,48 @@ struct CoalescePass
300
327
// loop body that use that argument.
301
328
template <typename OpType, typename = std::enable_if_t <llvm::is_one_of<
302
329
OpType, scf::ForOp, scf::WhileOp>::value>>
303
- void propagateLayoutToArgsAndBody (OpType loopOp, Operation *root ,
330
+ void propagateLayoutToArgsAndBody (OpType loopOp, Value opRes ,
304
331
Attribute layout,
305
332
IRRewriter &rewriter) const {
306
- assert (llvm::any_of (root->getUsers (),
307
- [&](Operation *user) { return user == loopOp; }) &&
308
- " Expecting the loop to be a user of the root operation" );
309
-
310
- for (BlockArgument arg : loopOp.getRegionIterArgs ()) {
311
- Value loopArg;
312
- if constexpr (std::is_same<OpType, scf::ForOp>::value)
313
- loopArg = loopOp.getInitArgs ()[arg.getArgNumber () - 1 ];
314
- if constexpr (std::is_same<OpType, scf::WhileOp>::value)
315
- loopArg = loopOp.getInits ()[arg.getArgNumber ()];
316
-
317
- for (OpResult res : root->getResults ()) {
318
- if (res != loopArg || !tt::isTensorPointerType (res.getType ()))
319
- continue ;
320
- // Modify the layout of the loop init argument...
321
- tt::PointerType ptrType = cast<tt::PointerType>(arg.getType ());
322
- auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType ());
323
- arg.setType (tt::PointerType::get (getNewType (tensorType, layout),
324
- ptrType.getAddressSpace ()));
325
- LLVM_DEBUG ({
326
- llvm::dbgs () << " [" DEBUG_TYPE " ]: " << " Propagated layout to: " ;
327
- arg.printAsOperand (llvm::dbgs (), {});
328
- llvm::dbgs () << " \n " ;
329
- });
330
-
331
- // ... and then propagate it to the operations in the loop.
332
- propagateLayout (arg, layout, rewriter);
333
- }
333
+ for (auto [initArg, arg] :
334
+ llvm::zip (loopOp.getInitsMutable (), loopOp.getRegionIterArgs ())) {
335
+ if (initArg.get () != opRes)
336
+ continue ;
337
+
338
+ // Modify the layout of the loop init argument...
339
+ auto ptrType = cast<tt::PointerType>(arg.getType ());
340
+ auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType ());
341
+ arg.setType (tt::PointerType::get (getNewType (tensorType, layout),
342
+ ptrType.getAddressSpace ()));
343
+
344
+ LLVM_DEBUG ({
345
+ llvm::dbgs () << " [" DEBUG_TYPE " ]: " << " Propagated layout to: " ;
346
+ arg.printAsOperand (llvm::dbgs (), {});
347
+ llvm::dbgs () << " \n " ;
348
+ });
349
+
350
+ // ... and then propagate it to the operations in the loop.
351
+ propagateLayout (arg, layout, rewriter);
334
352
}
335
353
}
336
354
337
- // Modify the given loop \p loopOpt and propagate its results to their users.
338
- template < typename OpType, typename = std:: enable_if_t <llvm::is_one_of<
339
- OpType, scf::ForOp, scf::WhileOp>::value>>
340
- void propagateLayoutToLoopResults (OpType loopOp, Attribute layout,
355
+ // Modify the \p layout to the loop's operand identified by \p resNum, and
356
+ // propagate the modified loop results to its users.
357
+ void propagateLayoutToLoopResults (LoopLikeOpInterface loopOp, unsigned resNum,
358
+ Attribute layout,
341
359
IRRewriter &rewriter) const {
342
- Operation *yieldOp = nullptr ;
343
- if constexpr (std::is_same<OpType, scf::ForOp>::value)
344
- yieldOp = loopOp.getBody ()->getTerminator ();
345
- if constexpr (std::is_same<OpType, scf::WhileOp>::value)
346
- yieldOp = loopOp.getYieldOp ();
347
-
360
+ Value loopRes = loopOp->getResult (resNum);
348
361
rewriter.modifyOpInPlace (loopOp, [&]() {
349
- for (auto [yieldOperandType, res] :
350
- llvm::zip (yieldOp->getOperandTypes (), loopOp.getResults ())) {
351
- Type resType = res.getType ();
352
- if (yieldOperandType == resType)
353
- continue ;
354
-
355
- assert (tt::isTensorPointerType (resType) &&
356
- tt::isTensorPointerType (yieldOperandType) &&
357
- " Expecting blocked pointers" );
358
- assert (cast<RankedTensorType>(
359
- cast<tt::PointerType>(yieldOperandType).getPointeeType ())
360
- .getEncoding () == layout &&
361
- " Unexpected layout" );
362
-
363
- auto ptrType = cast<tt::PointerType>(res.getType ());
364
- RankedTensorType tensorType = ttgi::getRankedTensorType (resType);
365
- res.setType (tt::PointerType::get (getNewType (tensorType, layout),
366
- ptrType.getAddressSpace ()));
367
- }
362
+ assert (tt::isTensorPointerType (loopRes.getType ()) &&
363
+ " Expecting blocked pointers" );
364
+ Type resType = loopRes.getType ();
365
+ auto ptrType = cast<tt::PointerType>(resType);
366
+ RankedTensorType tensorType = ttgi::getRankedTensorType (resType);
367
+ loopRes.setType (tt::PointerType::get (getNewType (tensorType, layout),
368
+ ptrType.getAddressSpace ()));
368
369
});
369
370
370
- propagateLayout (loopOp, layout, rewriter);
371
+ propagateLayout (loopOp, loopRes, layout, rewriter);
371
372
}
372
373
373
374
void coalesceOp (Attribute encoding, Operation *op) {
@@ -404,7 +405,8 @@ struct CoalescePass
404
405
}
405
406
406
407
IRRewriter rewriter (builder);
407
- changeAndPropagateLayout (*defOp, encoding, rewriter);
408
+ changeAndPropagateLayout (*defOp, defOp->getResult (), encoding,
409
+ rewriter);
408
410
newArgs.push_back (operand);
409
411
}
410
412
}
0 commit comments