11#include " intel/include/Analysis/AxisInfo.h"
22#include " intel/include/Dialect/TritonIntelGPU/IR/Utils.h"
33#include " intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
4+ #include " intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"
45#include " mlir/IR/Operation.h"
56#include " mlir/IR/Value.h"
67#include " mlir/IR/Verifier.h"
@@ -27,13 +28,6 @@ namespace ttgi = mlir::triton::gpu::intel;
2728
2829namespace {
2930
30- RankedTensorType getRankedTensorType (Type ptrTy) {
31- return tt::isTensorPointerType (ptrTy)
32- ? cast<RankedTensorType>(
33- cast<tt::PointerType>(ptrTy).getPointeeType ())
34- : dyn_cast<RankedTensorType>(ptrTy);
35- }
36-
3731struct CoalescePass
3832 : public ttgi::impl::TritonIntelGPUCoalesceBase<CoalescePass> {
3933private:
@@ -55,7 +49,7 @@ struct CoalescePass
5549 SmallVector<unsigned > order = argSort (contiguity);
5650 LDBG (" order=[" << triton::join (order, " , " ) << " ]" );
5751
58- RankedTensorType refTensorType = getRankedTensorType (ptr.getType ());
52+ RankedTensorType refTensorType = ttgi:: getRankedTensorType (ptr.getType ());
5953 auto matchesShape = [&refTensorType](const Value &val) {
6054 auto rttType = dyn_cast<RankedTensorType>(val.getType ());
6155 return rttType && rttType.getShape () == refTensorType.getShape ();
@@ -279,7 +273,7 @@ struct CoalescePass
279273 " Unexpected layout" );
280274
281275 auto resType = cast<tt::PointerType>(res.getType ());
282- RankedTensorType tensorType = getRankedTensorType (resType);
276+ RankedTensorType tensorType = ttgi:: getRankedTensorType (resType);
283277 res.setType (tt::PointerType::get (getNewType (tensorType, layout),
284278 resType.getAddressSpace ()));
285279 }
@@ -362,7 +356,7 @@ struct CoalescePass
362356 if (!ptr)
363357 return ;
364358
365- RankedTensorType refTensorType = getRankedTensorType (ptr.getType ());
359+ RankedTensorType refTensorType = ttgi:: getRankedTensorType (ptr.getType ());
366360 if (!refTensorType || !refTensorType.getEncoding ())
367361 return ;
368362
@@ -373,8 +367,7 @@ struct CoalescePass
373367 });
374368
375369 LLVM_DEBUG ({
376- DBGS () << " \n layoutMap:"
377- << " \n " ;
370+ DBGS () << " \n layoutMap:" << " \n " ;
378371 for (auto [op, encoding] : layoutMap) {
379372 DBGS () << " op: " << *op << " \n " ;
380373 DBGS () << " encoding: " << encoding << " \n\n " ;
0 commit comments