@@ -28,6 +28,13 @@ namespace ttgi = mlir::triton::gpu::intel;
 
 namespace {
 
+RankedTensorType getRankedTensorType(Type ptrTy) {
+  return tt::isTensorPointerType(ptrTy)
+             ? cast<RankedTensorType>(
+                   cast<tt::PointerType>(ptrTy).getPointeeType())
+             : dyn_cast<RankedTensorType>(ptrTy);
+}
+
 struct CoalescePass
     : public ttgi::impl::TritonIntelGPUCoalesceBase<CoalescePass> {
 private:
@@ -49,12 +56,7 @@ struct CoalescePass
     SmallVector<unsigned> order = argSort(contiguity);
     LDBG("order=[" << triton::join(order, ", ") << "]");
 
-    RankedTensorType refTensorType =
-        tt::isTensorPointerType(ptr.getType())
-            ? cast<RankedTensorType>(
-                  cast<tt::PointerType>(ptr.getType()).getPointeeType())
-            : cast<RankedTensorType>(ptr.getType());
-
+    RankedTensorType refTensorType = getRankedTensorType(ptr.getType());
     auto matchesShape = [&refTensorType](const Value &val) {
       auto rttType = dyn_cast<RankedTensorType>(val.getType());
       return rttType && rttType.getShape() == refTensorType.getShape();
@@ -197,7 +199,7 @@ struct CoalescePass
197199 " Unexpected layout" );
198200
199201 auto resType = cast<tt::PointerType>(res.getType ());
200- auto tensorType = cast<RankedTensorType> (resType. getPointeeType () );
202+ RankedTensorType tensorType = getRankedTensorType (resType);
201203 res.setType (tt::PointerType::get (getNewType (tensorType, layout),
202204 resType.getAddressSpace ()));
203205 }
@@ -243,7 +245,7 @@ struct CoalescePass
243245 " Unexpected layout" );
244246
245247 auto resType = cast<tt::PointerType>(res.getType ());
246- auto tensorType = cast<RankedTensorType> (resType. getPointeeType () );
248+ RankedTensorType tensorType = getRankedTensorType (resType);
247249 res.setType (tt::PointerType::get (getNewType (tensorType, layout),
248250 resType.getAddressSpace ()));
249251 }
@@ -281,8 +283,10 @@ struct CoalescePass
     }
   }
 
-  // Change the \p layout of the \p op result(s) and propagate the new result
-  // type to its users.
+  // TODO: change the implementation to handle only operations yielding one
+  // result?
+  // Change the \p layout of the \p op result(s) and propagate the new
+  // result type to its users.
   static void changeAndPropagateLayout(Operation *op, Attribute layout,
                                        IRRewriter &rewriter) {
     assert(op && op->getNumResults() != 0 &&
@@ -293,10 +297,6 @@ struct CoalescePass
       if (!tt::isTensorPointerType(res.getType()))
         continue;
 
-      // Problem: if the operation is a for loop we cannot modify the layout
-      // of all the tensor ptr results, we need to modify only the one used by
-      // the yield operation.
-
       auto ptrType = cast<tt::PointerType>(res.getType());
       auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
       res.setType(tt::PointerType::get(getNewType(tensorType, layout),
@@ -382,11 +382,7 @@ struct CoalescePass
     if (!ptr)
       return;
 
-    RankedTensorType refTensorType =
-        tt::isTensorPointerType(ptr.getType())
-            ? cast<RankedTensorType>(
-                  cast<tt::PointerType>(ptr.getType()).getPointeeType())
-            : dyn_cast<RankedTensorType>(ptr.getType());
+    RankedTensorType refTensorType = getRankedTensorType(ptr.getType());
     if (!refTensorType || !refTensorType.getEncoding())
       return;
 
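
The core of this commit is hoisting the repeated "unwrap the tensor-pointer type, else downcast" ternary into a single `getRankedTensorType` helper and reusing it at every call site. Below is a minimal standalone sketch of that pattern; the `Type`, `PointerType`, and `RankedTensorType` classes are simplified stand-ins using plain RTTI rather than MLIR's `cast`/`dyn_cast`, not the actual Triton or MLIR API.

```cpp
// Minimal sketch of the extracted-helper pattern (stand-in types, not MLIR).
#include <cassert>
#include <cstdint>
#include <vector>

struct Type {
  virtual ~Type() = default;
};
struct RankedTensorType : Type {
  std::vector<int64_t> shape;
};
struct PointerType : Type {
  Type *pointee = nullptr; // non-owning, for brevity
};

// True when the type is a pointer whose pointee is a ranked tensor,
// mirroring tt::isTensorPointerType in the diff.
static bool isTensorPointerType(Type *ty) {
  auto *ptr = dynamic_cast<PointerType *>(ty);
  return ptr && dynamic_cast<RankedTensorType *>(ptr->pointee);
}

// The extracted helper: peel the pointer if present, otherwise attempt a
// direct downcast; returns nullptr when neither applies (the dyn_cast path).
static RankedTensorType *getRankedTensorType(Type *ty) {
  if (isTensorPointerType(ty))
    return static_cast<RankedTensorType *>(
        static_cast<PointerType *>(ty)->pointee);
  return dynamic_cast<RankedTensorType *>(ty);
}

int main() {
  RankedTensorType tensor;
  tensor.shape = {16, 16};
  PointerType ptrToTensor;
  ptrToTensor.pointee = &tensor;
  PointerType ptrToNothing;

  assert(getRankedTensorType(&tensor) == &tensor);       // plain tensor
  assert(getRankedTensorType(&ptrToTensor) == &tensor);  // tensor pointer
  assert(getRankedTensorType(&ptrToNothing) == nullptr); // neither
  return 0;
}
```

Note that the helper returns null on the non-pointer path (the `dyn_cast` case), which is why the last hunk guards the result with `if (!refTensorType || !refTensorType.getEncoding())` before use.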