@@ -263,10 +263,10 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
263263 /// based on the rank and the value of the first stride dimension.
264264 bool isColMajor() {
265265 auto dim0 = dyn_cast<IntegerAttr>(getStrideAttr()[0]);
266- return getRank() == 2 && dim0 && dim0 .getInt() == 1;
266+ return getRank() == 2 && dim0.getInt() == 1;
267267 }
268268
269- // get the Blocking shape for a MemDescType, Which is represented
269+ // Get the Blocking shape for a MemDescType, Which is represented
270270 // as an attribute in MemDescType. By default it is the shape
271271 // of the mdescTy
272272 SmallVector<int64_t> getBlockShape() {
@@ -284,16 +284,18 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
284284 // Get strides as vector of integer.
285285 // If it contains block attribute, the strides are blocked strides.
286286 //
287- // The blocking is applied against the original matrix shape
288- // so that the linear offset is not impacted by the subview.
287+ // The blocking is applied to the base matrix shape derived from the
288+ // memory descriptor's stride information. If the matrix described by
289+ // the memory descriptor is not contiguous, it is assumed that the base
290+ // matrix is contiguous and follows the same memory layout.
289291 //
290292 // It first computes the original matrix shape using the stride info,
291293 // then computes the number of blocks in each dimension of original shape,
292294 // then compute the outer block shape and stride,
293295 // then combines the inner and outer block shape and stride
294- // e.g. for mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]>
296+ // e.g. for ` mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]>`
295297 // its memory layout tuple is ([2,32,16,8],[128,256,1,16])
296- // for mem_desc<256x32xf16, @block=[8, 16]> with default @stride[32, 1]
298+ // for ` mem_desc<256x32xf16, @block=[8, 16]>` with default @stride[32, 1]
297299 // its memory layout tuple is ([32,2,8,16],[256,128,16,1])
298300 SmallVector<int64_t> getStrideShape();
299301
0 commit comments