Skip to content

Commit 272f512

Browse files
committed
address comments
1 parent 966525b commit 272f512

File tree

6 files changed

+27
-17
lines changed

6 files changed

+27
-17
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,6 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
242242
if (layout && layout.hasAttr("stride")) {
243243
return layout.getStrides();
244244
}
245-
246245
// derive and return default strides
247246
SmallVector<int64_t> defaultStrides;
248247
llvm::append_range(defaultStrides, getShape().drop_front());
@@ -251,6 +250,15 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
251250
return builder.getI64ArrayAttr(defaultStrides);
252251
}
253252

253+
ArrayAttr getBlockAttr() {
254+
auto layout = getMemLayout();
255+
if (layout && layout.hasAttr("block")) {
256+
return layout.getBlockAttr();
257+
}
258+
Builder builder(getContext());
259+
return builder.getI64ArrayAttr({});
260+
}
261+
254262
/// Heuristic to determine if the MemDesc uses column-major layout,
255263
/// based on the rank and the value of the first stride dimension.
256264
bool isColMajor() {
@@ -261,16 +269,14 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
261269
// get the Blocking shape for a MemDescType, Which is represented
262270
// as an attribute in MemDescType. By default it is the shape
263271
// of the mdescTy
264-
SmallVector<int64_t> getBlockSize() {
272+
SmallVector<int64_t> getBlockShape() {
265273
SmallVector<int64_t> size(getShape());
266-
MemLayoutAttr layout = getMemLayout();
267-
if (layout && layout.hasAttr("block")) {
268-
ArrayAttr attr = layout.getBlockAttr();
274+
ArrayAttr blockAttr = getBlockAttr();
275+
if (!blockAttr.empty()) {
269276
size.clear();
270-
llvm::for_each(attr, [&](Attribute elem) {
271-
if (auto intElem = dyn_cast<IntegerAttr>(elem))
272-
size.push_back(intElem.getInt());
273-
});
277+
for (auto attr : blockAttr.getValue()) {
278+
size.push_back(cast<IntegerAttr>(attr).getInt());
279+
}
274280
}
275281
return size;
276282
}
@@ -289,7 +295,7 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
289295
// its memory layout tuple is ([2,32,16,8],[128,256,1,16])
290296
// for mem_desc<256x32xf16, @block=[8, 16]> with default @stride[32, 1]
291297
// its memory layout tuple is ([32,2,8,16],[256,128,16,1])
292-
SmallVector<int64_t> getStrides();
298+
SmallVector<int64_t> getStrideShape();
293299

294300
/// Generates instructions to compute the linearize offset
295301
// if the memory descriptor is blocked, it returns linearize offset based on the blocked layout

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -776,7 +776,7 @@ SmallVector<OpFoldResult> getBlockedOffsets(OpBuilder &builder, Location loc,
776776
}
777777

778778
// Get strides as vector of integer for MemDesc.
779-
SmallVector<int64_t> MemDescType::getStrides() {
779+
SmallVector<int64_t> MemDescType::getStrideShape() {
780780

781781
SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());
782782

@@ -786,7 +786,7 @@ SmallVector<int64_t> MemDescType::getStrides() {
786786
strides.push_back(cast<IntegerAttr>(attr).getInt());
787787
}
788788

789-
SmallVector<int64_t> innerBlkShape = getBlockSize();
789+
SmallVector<int64_t> innerBlkShape = getBlockShape();
790790

791791
// get perm from FCD to LCD
792792
// perm[i] = the dim with i-th smallest stride
@@ -837,8 +837,8 @@ Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
837837
ArrayRef<OpFoldResult> offsets) {
838838

839839
SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());
840-
SmallVector<int64_t> blockShape = getBlockSize();
841-
SmallVector<int64_t> strides = getStrides();
840+
SmallVector<int64_t> blockShape = getBlockShape();
841+
SmallVector<int64_t> strides = getStrideShape();
842842

843843
// blockshape equal to matrixshape means no blocking
844844
if (llvm::equal(blockShape, matrixShape)) {

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,10 +201,10 @@ IsValidStoreMatrixParams(VectorType dataTy, MemDescType mdescTy,
201201
return emitError() << "data shape must not exceed mem_desc shape.";
202202
} else if (dataShape.size() == 1) {
203203

204-
SmallVector<int64_t> blockSize = mdescTy.getBlockSize();
204+
SmallVector<int64_t> blockShape = mdescTy.getBlockShape();
205205
// if the subgroup_block_io attribute is set, mdescTy must have block
206206
// attribute
207-
if (subgroup_block_io && !blockSize.size())
207+
if (subgroup_block_io && !blockShape.size())
208208
return emitError() << "mem_desc must have block attribute when "
209209
"subgroup_block_io is set.";
210210
// if the subgroup_block_io attribute is set, the memdesc should be row

mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -942,6 +942,8 @@ struct UnrollLoadMatrixOp : public UnrollPattern<xegpu::LoadMatrixOp> {
942942
PatternRewriter &rewriter) const override {
943943
Location loc = op.getLoc();
944944
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
945+
assert(valueTy && "the value type must be vector type!");
946+
945947
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
946948
if (!targetShape || targetShape->size() != (size_t)valueTy.getRank())
947949
return failure();
@@ -985,6 +987,7 @@ struct UnrollStoreMatrixOp : public UnrollPattern<xegpu::StoreMatrixOp> {
985987

986988
Location loc = op.getLoc();
987989
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getData().getType());
990+
assert(valueTy && "the value type must be vector type!");
988991
ArrayRef<int64_t> shape = valueTy.getShape();
989992
auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());
990993

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -992,6 +992,7 @@ struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
992992

993993
ArrayRef<int64_t> wgShape = op.getDataShape();
994994
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType());
995+
assert(valueTy && "the value type must be vector type!");
995996
Type elemTy = valueTy.getElementType();
996997

997998
xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();

mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,4 +198,4 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
198198
gpu.return %1: vector<8xf16>
199199
}
200200

201-
}
201+
}

0 commit comments

Comments
 (0)