@@ -171,11 +171,77 @@ struct CoalescePass
171171 return false ;
172172 }
173173
174- // Propagate the layout to \p root operation's result to the \p forOp loop
174+ // Change the \p layout of the \p op result and propagate the new result type
175+ // to its users.
176+ void changeAndPropagateLayout (Operation *op, Attribute layout,
177+ IRRewriter &rewriter) const {
178+ assert (op && op->getNumResults () == 1 &&
179+ " Expecting operation yielding a result" );
180+
181+ rewriter.modifyOpInPlace (op, [&]() {
182+ Value res = op->getOpResult (0 );
183+ assert (tt::isTensorPointerType (res.getType ()) &&
184+ " Expecting a block pointer" );
185+
186+ auto ptrType = cast<tt::PointerType>(res.getType ());
187+ auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType ());
188+ res.setType (tt::PointerType::get (getNewType (tensorType, layout),
189+ ptrType.getAddressSpace ()));
190+ });
191+ LDBG (" Coalesced op: " << *op);
192+
193+ propagateLayout (op, layout, rewriter);
194+ }
195+
196+ // Propagate the layout of the \p root operation's result to its users.
197+ void propagateLayout (Operation *root, Attribute layout,
198+ IRRewriter &rewriter) const {
199+ assert (root->getNumResults () != 0 &&
200+ " Expecting an operation yielding a result" );
201+
202+ LDBG (" root: " << *root);
203+ for (Operation *user : root->getUsers ()) {
204+ if (filterUser (user))
205+ continue ;
206+
207+ LDBG (" root's user: " << *user << " \n " );
208+ if (auto forOp = dyn_cast<scf::ForOp>(user)) {
209+ propagateLayoutToArgsAndBody (forOp, root, layout, rewriter);
210+ continue ;
211+ }
212+ if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
213+ auto forOp = yieldOp->getParentOfType <scf::ForOp>();
214+ propagateLayoutToLoopResults (forOp, layout, rewriter);
215+ continue ;
216+ }
217+ changeAndPropagateLayout (user, layout, rewriter);
218+ }
219+ }
220+
221+ // Propagate the layout of the \p arg block argument to its users.
222+ void propagateLayout (BlockArgument arg, Attribute layout,
223+ IRRewriter &rewriter) const {
224+ LDBG (" arg: " << arg);
225+ for (Operation *user : arg.getUsers ()) {
226+ if (filterUser (user))
227+ continue ;
228+
229+ LDBG (" arg's user: " << *user << " \n " );
230+ if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
231+ auto forOp = yieldOp->getParentOfType <scf::ForOp>();
232+ propagateLayoutToLoopResults (forOp, layout, rewriter);
233+ continue ;
234+ }
235+ changeAndPropagateLayout (user, layout, rewriter);
236+ }
237+ }
238+
239+ // Propagate the layout of the \p root operation's result to the \p forOp loop
175240 // init argument that uses it, and transitively to the operations in the loop
176241 // body that use that argument.
177- static void propagate (scf::ForOp forOp, Operation *root, Attribute layout,
178- IRRewriter &rewriter) {
242+ void propagateLayoutToArgsAndBody (scf::ForOp forOp, Operation *root,
243+ Attribute layout,
244+ IRRewriter &rewriter) const {
179245 assert (llvm::any_of (root->getUsers (),
180246 [&](Operation *user) { return user == forOp; }) &&
181247 " Expecting the loop to be a user of the root operation" );
@@ -202,8 +268,8 @@ struct CoalescePass
202268
203269 // Modify the given loop \p forOp and propagate the result of the enclosing
204270 // loop.
205- static void propagate (scf::ForOp forOp, Attribute layout,
206- IRRewriter &rewriter) {
271+ void propagateLayoutToLoopResults (scf::ForOp forOp, Attribute layout,
272+ IRRewriter &rewriter) const {
207273 Operation *yieldOp = forOp.getBody ()->getTerminator ();
208274
209275 rewriter.modifyOpInPlace (forOp, [&]() {
@@ -229,69 +295,6 @@ struct CoalescePass
229295 propagateLayout (forOp, layout, rewriter);
230296 }
231297
232- static void propagateLayout (BlockArgument arg, Attribute layout,
233- IRRewriter &rewriter) {
234- LDBG (" arg: " << arg);
235- for (Operation *user : arg.getUsers ()) {
236- LDBG (" arg's user: " << *user << " \n " );
237- if (filterUser (user)) {
238- continue ;
239- }
240- if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
241- auto forOp = yieldOp->getParentOfType <scf::ForOp>();
242- propagate (forOp, layout, rewriter);
243- continue ;
244- }
245- changeAndPropagateLayout (user, layout, rewriter);
246- }
247- }
248-
249- static void propagateLayout (Operation *root, Attribute layout,
250- IRRewriter &rewriter) {
251- assert (root->getNumResults () != 0 &&
252- " Expecting an operation yielding a result" );
253-
254- LDBG (" root: " << *root);
255- for (Operation *user : root->getUsers ()) {
256- LDBG (" root's user: " << *user << " \n " );
257- if (filterUser (user)) {
258- continue ;
259- }
260- if (auto forOp = dyn_cast<scf::ForOp>(user)) {
261- propagate (forOp, root, layout, rewriter);
262- continue ;
263- }
264- if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
265- auto forOp = yieldOp->getParentOfType <scf::ForOp>();
266- propagate (forOp, layout, rewriter);
267- continue ;
268- }
269- changeAndPropagateLayout (user, layout, rewriter);
270- }
271- }
272-
273- // Change the \p layout of the \p op result and propagate the new result type
274- // to its users.
275- static void changeAndPropagateLayout (Operation *op, Attribute layout,
276- IRRewriter &rewriter) {
277- assert (op && op->getNumResults () == 1 &&
278- " Expecting operation yielding a result" );
279-
280- rewriter.modifyOpInPlace (op, [&]() {
281- Value res = op->getOpResult (0 );
282- assert (tt::isTensorPointerType (res.getType ()) &&
283- " Expecting a block pointer" );
284-
285- auto ptrType = cast<tt::PointerType>(res.getType ());
286- auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType ());
287- res.setType (tt::PointerType::get (getNewType (tensorType, layout),
288- ptrType.getAddressSpace ()));
289- });
290- LDBG (" Coalesced op: " << *op);
291-
292- propagateLayout (op, layout, rewriter);
293- }
294-
295298 void coalesceOp (Attribute encoding, Operation *op) {
296299 LDBG (" Coalescing op: " << *op);
297300
@@ -316,8 +319,7 @@ struct CoalescePass
316319 " Expecting operand to have blocked pointer type" );
317320 auto defOp = findDefiningMakeTensorPtrOp (operand);
318321 assert (defOp && " Expected a make_tensor_ptr operation" );
319-
320- llvm::errs () << " Found make_tensor_ptr definition: " << *defOp << " \n " ;
322+ LDBG (" Found make_tensor_ptr definition: " << *defOp);
321323 changeAndPropagateLayout (*defOp, encoding, rewriter);
322324 newArgs.push_back (operand);
323325 }
@@ -326,8 +328,7 @@ struct CoalescePass
326328 // Convert output types
327329 SmallVector<Type, 4 > newTypes;
328330 for (auto t : op->getResultTypes ()) {
329- bool isAsync = isa<ttg::AsyncCopyGlobalToLocalOp>(op);
330- assert (!isAsync &&
331+ assert (!isa<ttg::AsyncCopyGlobalToLocalOp>(op) &&
331332 " AsyncCopyGlobalToLocalOp not supported for Intel GPU" );
332333 newTypes.push_back (getNewType (cast<RankedTensorType>(t), encoding));
333334 }
@@ -379,7 +380,8 @@ struct CoalescePass
379380 });
380381
381382 LLVM_DEBUG ({
382- DBGS () << " layoutMap:" << " \n " ;
383+ DBGS () << " layoutMap:"
384+ << " \n " ;
383385 for (auto [op, encoding] : layoutMap) {
384386 DBGS () << " op: " << *op << " \n " ;
385387 DBGS () << " encoding: " << encoding << " \n " ;
@@ -398,20 +400,10 @@ struct CoalescePass
398400 coalesceOp (layout, op);
399401 }
400402
401- if (failed (verify (moduleOp))) {
402- llvm::errs () << " Module verification failed.\n " ;
403- llvm::errs () << " mod: " << moduleOp << " \n " ;
404- for (Operation &op1 : moduleOp.getOps ()) {
405- if (isa<tt::FuncOp>(op1)) {
406- for (Operation &op2 : cast<tt::FuncOp>(op1).getOps ()) {
407- if (failed (verify (&op2))) {
408- llvm::errs () << " op2: " << op2 << " \n " ;
409- llvm::errs () << " Operation verification failed.\n " ;
410- assert (false );
411- }
412- }
413- }
414- }
403+ // Verify the module's functions after the transformation.
404+ for (auto op : moduleOp.getOps <tt::FuncOp>()) {
405+ for (Operation &op1 : op.getOps ())
406+ assert (succeeded (verify (&op1)));
415407 }
416408 }
417409};
0 commit comments