@@ -19,6 +19,12 @@ using namespace mlir::xegpu;
1919
2020namespace {
2121
22+
23+ #define DEBUG_TYPE " test-xegpu-unroll"
24+ #define DBGS () (llvm::dbgs() << " [" DEBUG_TYPE " ]: " )
25+ #define LDBG (X ) LLVM_DEBUG(DBGS() << X << " \n " )
26+
27+
2228struct TestXeGPUUnrollingPatterns
2329 : public PassWrapper<TestXeGPUUnrollingPatterns,
2430 OperationPass<gpu::GPUModuleOp>> {
@@ -48,33 +54,21 @@ struct TestXeGPUUnrollingPatterns
4854 options.setNativeShapeFn (
4955 [&](Operation *op) -> std::optional<SmallVector<int64_t >> {
5056 if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp,
51- xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::StoreNdOp>(op)) {
57+ xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
58+ xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp,
59+ xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
5260 xegpu::TensorDescType tdescTy;
5361 if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
5462 tdescTy = createNdOp.getType ();
55- } else if (auto updateNdOp =
56- dyn_cast<xegpu::UpdateNdOffsetOp>(op)) {
63+ } else if (auto updateNdOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op)) {
5764 tdescTy = updateNdOp.getTensorDescType ();
5865 } else if (auto prefetchNdOp = dyn_cast<xegpu::PrefetchNdOp>(op)) {
5966 tdescTy = prefetchNdOp.getTensorDescType ();
6067 } else if (auto loadNdOp = dyn_cast<xegpu::LoadNdOp>(op)) {
6168 tdescTy = loadNdOp.getTensorDescType ();
6269 } else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
6370 tdescTy = storeNdOp.getTensorDescType ();
64- }
65-
66- if (auto layout = tdescTy.getLayoutAttr ()) {
67- auto inst_data = layout.getInstData ();
68- if (inst_data && layout.isSgLayout ())
69- return SmallVector<int64_t >(inst_data.asArrayRef ().begin (),
70- inst_data.asArrayRef ().end ());
71- }
72- }
73-
74- if (isa<xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp,
75- xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
76- xegpu::TensorDescType tdescTy;
77- if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
71+ } else if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
7872 tdescTy = createOp.getType ();
7973 } else if (auto updateOp = dyn_cast<xegpu::UpdateOffsetOp>(op)) {
8074 tdescTy = updateOp.getTensorDescType ();
@@ -111,14 +105,40 @@ struct TestXeGPUUnrollingPatterns
111105 Attribute encoding = tdescTy.getEncoding ();
112106 auto layout = llvm::dyn_cast_if_present<xegpu::LayoutAttr>(
113107 tdescTy.getLayout ());
108+
109+ int64_t newChunkSize = 0 ;
110+ auto instData = layout.getInstData ();
111+ if (!instData.empty ())
112+ newChunkSize = instData.asArrayRef ().back ();
113+
114114 if (layout) {
115115 if (layout.getLaneLayout () == nullptr )
116116 layout = xegpu::LayoutAttr ();
117117 else
118118 layout = layout.dropInstData ();
119119 }
120- newTy = xegpu::TensorDescType::get (ctx, tileShape, elemTy, encoding,
121- layout);
120+
121+ SmallVector<NamedAttribute> attrs;
122+ auto scatterAttr = mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
123+ if (scatterAttr) {
124+ int64_t chunkSize = scatterAttr.getChunkSize ().getInt ();
125+
126+ if (chunkSize > 1 ) {
127+
128+ auto chunkSizeAttr = mlir::IntegerAttr::get (
129+ mlir::IntegerType::get (ctx, 64 ), newChunkSize);
130+
131+ // To create a new attribute with a different chunk_size:
132+ auto newEncoding = xegpu::ScatterTensorDescAttr::get (
133+ ctx, scatterAttr.getMemorySpace (), chunkSizeAttr);
134+
135+ encoding = newEncoding;
136+
137+ }
138+
139+ }
140+ newTy = xegpu::TensorDescType::get (ctx, tileShape, elemTy, encoding, layout);
141+
122142 } else {
123143 newTy = type.clone (tileShape, elemTy);
124144 }
0 commit comments