@@ -242,7 +242,6 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
242242 if (layout && layout.hasAttr("stride")) {
243243 return layout.getStrides();
244244 }
245-
246245 // derive and return default strides
247246 SmallVector<int64_t> defaultStrides;
248247 llvm::append_range(defaultStrides, getShape().drop_front());
@@ -251,6 +250,15 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
251250 return builder.getI64ArrayAttr(defaultStrides);
252251 }
253252
253+ ArrayAttr getBlockAttr() {
254+ auto layout = getMemLayout();
255+ if (layout && layout.hasAttr("block")) {
256+ return layout.getBlockAttr();
257+ }
258+ Builder builder(getContext());
259+ return builder.getI64ArrayAttr({});
260+ }
261+
254262 /// Heuristic to determine if the MemDesc uses column-major layout,
255263 /// based on the rank and the value of the first stride dimension.
256264 bool isColMajor() {
@@ -261,16 +269,14 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
261269 // get the Blocking shape for a MemDescType, Which is represented
262270 // as an attribute in MemDescType. By default it is the shape
263271 // of the mdescTy
264- SmallVector<int64_t> getBlockSize () {
272+ SmallVector<int64_t> getBlockShape () {
265273 SmallVector<int64_t> size(getShape());
266- MemLayoutAttr layout = getMemLayout();
267- if (layout && layout.hasAttr("block")) {
268- ArrayAttr attr = layout.getBlockAttr();
274+ ArrayAttr blockAttr = getBlockAttr();
275+ if (!blockAttr.empty()) {
269276 size.clear();
270- llvm::for_each(attr, [&](Attribute elem) {
271- if (auto intElem = dyn_cast<IntegerAttr>(elem))
272- size.push_back(intElem.getInt());
273- });
277+ for (auto attr : blockAttr.getValue()) {
278+ size.push_back(cast<IntegerAttr>(attr).getInt());
279+ }
274280 }
275281 return size;
276282 }
@@ -289,7 +295,7 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
289295 // its memory layout tuple is ([2,32,16,8],[128,256,1,16])
290296 // for mem_desc<256x32xf16, @block=[8, 16]> with default @stride[32, 1]
291297 // its memory layout tuple is ([32,2,8,16],[256,128,16,1])
292- SmallVector<int64_t> getStrides ();
298+ SmallVector<int64_t> getStrideShape ();
293299
294300 /// Generates instructions to compute the linearize offset
295301 // if the memory descriptor is blocked, it returns linearize offset based on the blocked layout
0 commit comments