Skip to content

Commit faa0bfb

Browse files
committed
address comments
1 parent de87d09 commit faa0bfb

File tree

3 files changed

+15
-7
lines changed

3 files changed

+15
-7
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1319,6 +1319,9 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
13191319
Arguments:
13201320
- `mem_desc`: the memory descriptor identifying the SLM region.
13211321
- `offsets`: the coordinates within the matrix to read from.
1322+
- `subgroup_block_io`: [optional] An attribute indicating that the operation can be
1323+
lowered to a subgroup block load. When this attribute is present,
1324+
the offsets are subgroup-uniform across all lanes.
13221325
- `layout`: [optional] An attribute for guiding distributions among
13231326
subgroups and/or work-items. It currently can accept either
13241327
LayoutAttr or SliceAttr.
@@ -1367,6 +1370,9 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
13671370
- `mem_desc`: the memory descriptor specifying the SLM region.
13681371
- `offsets`: the coordinates within the matrix where the data will be written.
13691372
- `data`: the values to be stored in the matrix.
1373+
- `subgroup_block_io`: [optional] An attribute indicating that the operation can be
1374+
lowered to a subgroup block store. When this attribute is present,
1375+
the offsets are subgroup-uniform across all lanes.
13701376
- `layout`: [optional] An attribute for guiding distributions among
13711377
subgroups and/or work-items. It currently can accept either
13721378
LayoutAttr or SliceAttr.

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -263,10 +263,10 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
263263
/// based on the rank and the value of the first stride dimension.
264264
bool isColMajor() {
265265
auto dim0 = dyn_cast<IntegerAttr>(getStrideAttr()[0]);
266-
return getRank() == 2 && dim0 && dim0.getInt() == 1;
266+
return getRank() == 2 && dim0.getInt() == 1;
267267
}
268268

269-
// get the Blocking shape for a MemDescType, Which is represented
269+
// Get the blocking shape for a MemDescType, which is represented
270270
// as an attribute in MemDescType. By default it is the shape
271271
// of the mdescTy
272272
SmallVector<int64_t> getBlockShape() {
@@ -284,16 +284,18 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
284284
// Get strides as vector of integer.
285285
// If it contains block attribute, the strides are blocked strides.
286286
//
287-
// The blocking is applied against the original matrix shape
288-
// so that the linear offset is not impacted by the subview.
287+
// The blocking is applied to the base matrix shape derived from the
288+
// memory descriptor's stride information. If the matrix described by
289+
// the memory descriptor is not contiguous, it is assumed that the base
290+
// matrix is contiguous and follows the same memory layout.
289291
//
290292
// It first computes the original matrix shape using the stride info,
291293
// then computes the number of blocks in each dimension of original shape,
292294
// then computes the outer block shape and stride,
293295
// then combines the inner and outer block shape and stride
294-
// e.g. for mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]>
296+
// e.g. for `mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]>`
295297
// its memory layout tuple is ([2,32,16,8],[128,256,1,16])
296-
// for mem_desc<256x32xf16, @block=[8, 16]> with default @stride[32, 1]
298+
// for `mem_desc<256x32xf16, @block=[8, 16]>` with default @strides=[32, 1]
297299
// its memory layout tuple is ([32,2,8,16],[256,128,16,1])
298300
SmallVector<int64_t> getStrideShape();
299301

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,7 @@ class CreateMemDescOpPattern final
520520
matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
521521
ConversionPatternRewriter &rewriter) const override {
522522

523-
auto resTy = cast<xegpu::MemDescType>(op.getResult().getType());
523+
auto resTy = op.getMemDesc();
524524

525525
// Create the result MemRefType with the same shape, element type, and
526526
// memory space

0 commit comments

Comments
 (0)