 #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
 #include "intel/include/Dialect/TritonIntelGPU/IR/Utils.h"
 #include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/Support/LLVM.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Tools/StrUtil.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
+#include <variant>
 
 #define DEBUG_TYPE "tritonintelgpu-coalesce"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
@@ -150,7 +153,7 @@ struct CoalescePass
 
   static bool filterUser(Operation *op) {
     // Yield operations trigger updating the layout of the containing loop
-    // results, so don't skip them.
+    // results; don't skip them.
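+    // Illustrative IR (simplified): if the encoding of a yielded block
+    // pointer changes, the type of the corresponding loop result must be
+    // updated to match:
+    //
+    //   %r = scf.for ... iter_args(%p = %init)
+    //       -> (!tt.ptr<tensor<64x64xf16, #blocked>>) {
+    //     ...
+    //     scf.yield %next : !tt.ptr<tensor<64x64xf16, #blocked>>
+    //   }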
     if (isa<scf::YieldOp>(op))
       return false;
 
@@ -168,154 +171,123 @@ struct CoalescePass
     return false;
   }
 
-  // Propagate the \p root block argument operation output layout along the
-  // def-use chain.
-  static void propagateLayout(BlockArgument arg, Attribute layout,
-                              IRRewriter &rewriter) {
-    llvm::errs() << "arg: " << arg << "\n";
+  // Propagate the layout of the \p root operation's result to the \p forOp
+  // loop init argument that uses it, and transitively to the operations in
+  // the loop body that use that argument.
+  static void propagate(scf::ForOp forOp, Operation *root, Attribute layout,
+                        IRRewriter &rewriter) {
+    assert(llvm::any_of(root->getUsers(),
+                        [&](Operation *user) { return user == forOp; }) &&
+           "Expecting the loop to be a user of the root operation");
+
+    for (BlockArgument arg : forOp.getRegionIterArgs()) {
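+      // Region iter args start at block-argument index 1 (index 0 is the
+      // induction variable), hence the `- 1` when indexing the init args.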
+      Value loopArg = forOp.getInitArgs()[arg.getArgNumber() - 1];
+      for (OpResult res : root->getResults()) {
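+        // Only rewrite the layout when this root result is the value feeding
+        // this init argument and is a block (tensor) pointer.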
+        if (res != loopArg || !tt::isTensorPointerType(res.getType()))
+          continue;
 
-    auto users = arg.getUsers();
-    if (users.empty()) {
-      llvm::errs() << "arg has no users\n";
-      return;
+        LDBG("loopArg: " << loopArg);
+
+        // Modify the layout of the loop init argument...
+        tt::PointerType ptrType = cast<tt::PointerType>(arg.getType());
+        auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
+        arg.setType(tt::PointerType::get(getNewType(tensorType, layout),
+                                         ptrType.getAddressSpace()));
+
+        // ... and then propagate it to the operations in the loop.
+        propagateLayout(arg, layout, rewriter);
+      }
     }
+  }
+
+  // Update the result types of the given \p forOp loop to reflect the new
+  // \p layout, and propagate the layout to the users of the loop results.
+  static void propagate(scf::ForOp forOp, Attribute layout,
+                        IRRewriter &rewriter) {
+    Operation *yieldOp = forOp.getBody()->getTerminator();
+
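+    // Rewrite each loop result type to match the (already rewritten) type of
+    // the corresponding yield operand.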
+    rewriter.modifyOpInPlace(forOp, [&]() {
+      for (auto [opType, res] :
+           llvm::zip(yieldOp->getOperandTypes(), forOp.getResults())) {
+        if (opType == res.getType())
+          continue;
+
+        assert(tt::isTensorPointerType(res.getType()) &&
+               tt::isTensorPointerType(opType) && "Expecting blocked pointers");
+        assert(cast<RankedTensorType>(
+                   cast<tt::PointerType>(opType).getPointeeType())
+                       .getEncoding() == layout &&
+               "Unexpected layout");
 
-    for (Operation *user : users) {
-      llvm::errs() << "arg's user: " << *user << "\n\n";
+        auto resType = cast<tt::PointerType>(res.getType());
+        RankedTensorType tensorType = getRankedTensorType(resType);
+        res.setType(tt::PointerType::get(getNewType(tensorType, layout),
+                                         resType.getAddressSpace()));
+      }
+    });
+
+    propagateLayout(forOp, layout, rewriter);
+  }
+
+  static void propagateLayout(BlockArgument arg, Attribute layout,
+                              IRRewriter &rewriter) {
+    LDBG("arg: " << arg);
+    for (Operation *user : arg.getUsers()) {
+      LDBG("arg's user: " << *user << "\n");
       if (filterUser(user)) {
-        llvm::errs() << "SKIP\n";
         continue;
       }
-
       if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
-        // Modify and propagate the result of the enclosing loop.
         auto forOp = yieldOp->getParentOfType<scf::ForOp>();
-
-        rewriter.modifyOpInPlace(forOp, [&]() {
-          for (auto [opType, res] :
-               llvm::zip(yieldOp->getOperandTypes(), forOp.getResults())) {
-            if (opType == res.getType())
-              continue;
-
-            assert(tt::isTensorPointerType(res.getType()) &&
-                   tt::isTensorPointerType(opType) &&
-                   "Expecting blocked pointers");
-            assert(cast<RankedTensorType>(
-                       cast<tt::PointerType>(opType).getPointeeType())
-                           .getEncoding() == layout &&
-                   "Unexpected layout");
-
-            auto resType = cast<tt::PointerType>(res.getType());
-            RankedTensorType tensorType = getRankedTensorType(resType);
-            res.setType(tt::PointerType::get(getNewType(tensorType, layout),
-                                             resType.getAddressSpace()));
-          }
-        });
-
-        propagateLayout(forOp, layout, rewriter);
+        propagate(forOp, layout, rewriter);
         continue;
       }
-
       changeAndPropagateLayout(user, layout, rewriter);
     }
   }
 
   static void propagateLayout(Operation *root, Attribute layout,
                               IRRewriter &rewriter) {
-    assert(root && root->getNumResults() != 0 &&
+    assert(root->getNumResults() != 0 &&
            "Expecting an operation yielding a result");
 
-    llvm::errs() << "root: " << *root << "\n";
-    auto users = root->getUsers();
-    if (users.empty()) {
-      llvm::errs() << "root has no users\n";
-      return;
-    }
-
-    for (Operation *user : users) {
-      llvm::errs() << "root's user: " << *user << "\n\n";
+    LDBG("root: " << *root);
+    for (Operation *user : root->getUsers()) {
+      LDBG("root's user: " << *user << "\n");
       if (filterUser(user)) {
-        llvm::errs() << "SKIP\n";
         continue;
       }
-
-      if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
-        // Modify and propagate the result of the enclosing loop.
-        auto forOp = yieldOp->getParentOfType<scf::ForOp>();
-
-        rewriter.modifyOpInPlace(forOp, [&]() {
-          for (auto [opType, res] :
-               llvm::zip(yieldOp->getOperandTypes(), forOp.getResults())) {
-            if (opType == res.getType())
-              continue;
-
-            assert(tt::isTensorPointerType(res.getType()) &&
-                   tt::isTensorPointerType(opType) &&
-                   "Expecting blocked pointers");
-            assert(cast<RankedTensorType>(
-                       cast<tt::PointerType>(opType).getPointeeType())
-                           .getEncoding() == layout &&
-                   "Unexpected layout");
-
-            auto resType = cast<tt::PointerType>(res.getType());
-            RankedTensorType tensorType = getRankedTensorType(resType);
-            res.setType(tt::PointerType::get(getNewType(tensorType, layout),
-                                             resType.getAddressSpace()));
-          }
-        });
-
-        propagateLayout(forOp, layout, rewriter);
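+      // The loop uses the root result as an init argument: propagate the new
+      // layout into the loop body.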
+      if (auto forOp = dyn_cast<scf::ForOp>(user)) {
+        propagate(forOp, root, layout, rewriter);
         continue;
       }
-
-      if (auto forOp = dyn_cast<scf::ForOp>(user)) {
-        for (BlockArgument arg : forOp.getRegionIterArgs()) {
-          Value loopArg = forOp.getInitArgs()[arg.getArgNumber() - 1];
-          for (OpResult res : root->getResults()) {
-            if (res == loopArg && tt::isTensorPointerType(res.getType())) {
-              llvm::errs() << "arg: " << arg << "\n";
-              llvm::errs() << "loopArg: " << loopArg << "\n";
-
-              // Modify the layout of the loop init argument...
-              tt::PointerType ptrType = cast<tt::PointerType>(arg.getType());
-              auto tensorType =
-                  cast<RankedTensorType>(ptrType.getPointeeType());
-              arg.setType(tt::PointerType::get(getNewType(tensorType, layout),
-                                               ptrType.getAddressSpace()));
-
-              // ... and then propagate it to the operations in the loop.
-              propagateLayout(arg, layout, rewriter);
-            }
-          }
-        }
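+      // The root result is yielded by the enclosing loop: update the loop
+      // result types and propagate the layout to their users.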
+      if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
+        auto forOp = yieldOp->getParentOfType<scf::ForOp>();
+        propagate(forOp, layout, rewriter);
         continue;
       }
-
       changeAndPropagateLayout(user, layout, rewriter);
     }
   }
 
-  // TODO: change the implementation to handle only operation yielding one
-  // result?
-  // Change the \p layout of the \p op result(s) and propagate the new
-  // result type to its users.
+  // Change the \p layout of the \p op result and propagate the new result
+  // type to its users.
   static void changeAndPropagateLayout(Operation *op, Attribute layout,
                                        IRRewriter &rewriter) {
-    assert(op && op->getNumResults() != 0 &&
+    assert(op && op->getNumResults() == 1 &&
            "Expecting operation yielding a result");
 
     rewriter.modifyOpInPlace(op, [&]() {
-      for (Value res : op->getResults()) {
-        if (!tt::isTensorPointerType(res.getType()))
-          continue;
-
-        auto ptrType = cast<tt::PointerType>(res.getType());
-        auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
-        res.setType(tt::PointerType::get(getNewType(tensorType, layout),
-                                         ptrType.getAddressSpace()));
-      }
+      Value res = op->getOpResult(0);
+      assert(tt::isTensorPointerType(res.getType()) &&
+             "Expecting a block pointer");
+
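+      // Assumption: getNewType preserves the tensor shape and element type,
+      // swapping only the encoding.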
+      auto ptrType = cast<tt::PointerType>(res.getType());
+      auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
+      res.setType(tt::PointerType::get(getNewType(tensorType, layout),
+                                       ptrType.getAddressSpace()));
     });
-    llvm::errs() << "Coalesced op: " << *op << "\n";
+    LDBG("Coalesced op: " << *op);
 
     propagateLayout(op, layout, rewriter);
   }
@@ -400,22 +372,14 @@ struct CoalescePass
       if (!refTensorType || !refTensorType.getEncoding())
         return;
 
-      // static int n = 0;
-      // if (tt::isTensorPointerType(ptr.getType()))
-      //   n++;
-
-      // if (n != 2)
-      //   return;
-
       int numWarps = ttg::TritonGPUDialect::getNumWarps(moduleOp);
       int threadsPerWarp = ttg::TritonGPUDialect::getThreadsPerWarp(moduleOp);
       setCoalescedEncoding(axisInfoAnalysis, curr, numWarps, threadsPerWarp,
                            layoutMap);
     });
 
     LLVM_DEBUG({
-      DBGS() << "layoutMap:"
-             << "\n";
+      DBGS() << "layoutMap:" << "\n";
       for (auto [op, encoding] : layoutMap) {
         DBGS() << "op: " << *op << "\n";
         DBGS() << "encoding: " << encoding << "\n";