Skip to content

Commit 546a3f7

Browse files
committed
addverifier for invalid cases
1 parent cd518d2 commit 546a3f7

File tree

3 files changed

+53
-15
lines changed

3 files changed

+53
-15
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,17 +196,22 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
196196
return attr;
197197
}
198198

199-
200199
/// wrapper for matching with OffsetSizeAndStrideOpInterface
201200
/// If source is IntegerType or `const_shape` is filled,
202201
/// it will return `const_shape`, such that mixes of `shape`
203202
/// and `const_shape` will be used to represent the shape of
204203
/// source operand. They overide static shape from source memref type.
205204
ArrayRef<int64_t> getStaticSizes() {
205+
/// To be compatible with OffsetSizeAndStrideOpInterface, which expects valid return value and perform checks
206+
static llvm::SmallVector<int64_t, 4> emptyShape;
207+
206208
auto attr = getConstShapeAttr();
207-
if (llvm::isa<IntegerType>(getSourceType()) || attr)
209+
if (attr)
208210
return attr;
209211

212+
if (llvm::isa<IntegerType>(getSourceType()))
213+
return emptyShape;
214+
210215
auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
211216
assert(memrefType && "Incorrect use of getStaticSizes");
212217
return memrefType.getShape();
@@ -218,9 +223,15 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
218223
/// and `const_strides` will be used to represent the strides of
219224
/// source operand. They overide static strides from source memref type.
220225
ArrayRef<int64_t> getStaticStrides() {
226+
/// To be compatible with OffsetSizeAndStrideOpInterface, which expects valid return value and perform checks
227+
static llvm::SmallVector<int64_t, 4> emptyStrides;
228+
221229
auto attr = getConstStridesAttr();
222-
if (llvm::isa<IntegerType>(getSourceType()) || attr)
230+
if (attr)
223231
return attr;
232+
233+
if (llvm::isa<IntegerType>(getSourceType()))
234+
return emptyStrides;
224235

225236
auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
226237
assert(memrefType && "Incorrect use of getStaticStrides");

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
116116
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
117117
Type tdesc, TypedValue<MemRefType> source) {
118118
[[maybe_unused]] auto ty = source.getType();
119-
assert(ty.hasStaticShape());
119+
assert(ty.hasStaticShape() && "expecting a memref with static shape");
120120

121121
build(builder, state, tdesc, source, ValueRange({}) /* dynamic offsets */,
122122
ValueRange({}) /* empty dynamic shape */,
@@ -130,7 +130,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
130130
Type tdesc, TypedValue<MemRefType> source,
131131
llvm::ArrayRef<OpFoldResult> shape,
132132
llvm::ArrayRef<OpFoldResult> strides) {
133-
assert(shape.size() && strides.size() && shape.size() == strides.size());
133+
assert(shape.size() && strides.size() && shape.size() == strides.size() &&
134+
"Shape and strides must be present and of equal size for ui64 initialization.");
134135

135136
llvm::SmallVector<int64_t> staticShape;
136137
llvm::SmallVector<int64_t> staticStrides;
@@ -152,7 +153,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
152153
Type tdesc, TypedValue<IntegerType> source,
153154
llvm::ArrayRef<OpFoldResult> shape,
154155
llvm::ArrayRef<OpFoldResult> strides) {
155-
assert(shape.size() && strides.size() && shape.size() == strides.size());
156+
assert(shape.size() && strides.size() && shape.size() == strides.size() &&
157+
"Shape and strides must be present and of equal size for ui64 initialization.");
156158

157159
llvm::SmallVector<int64_t> staticShape;
158160
llvm::SmallVector<int64_t> staticStrides;
@@ -255,6 +257,13 @@ LogicalResult CreateNdDescOp::verify() {
255257
invalidElemTy |= memrefTy.getElementType() != getElementType();
256258
}
257259

260+
if (llvm::isa<IntegerType>(getSourceType()) ) {
261+
// strides and shape must present for integer source.
262+
if (getMixedStrides().empty() || getMixedSizes().empty())
263+
return emitOpError("Expecting strides and shape to be present for "
264+
"integer source.");
265+
}
266+
258267
// mismatches among shape, strides, and offsets are
259268
// already handeled by OffsetSizeAndStrideOpInterface.
260269
// So they are not check here.
@@ -301,18 +310,21 @@ ParseResult parseOptionalDynamicIndexList(
301310
return failure();
302311
integerVals.push_back(integer);
303312
}
304-
305313
return success();
306314
};
315+
316+
//If the optional values are given there must be left bracket
307317
if (parser.parseOptionalLSquare().succeeded()) {
308318
if (parser.parseCommaSeparatedList(parseIntegerOrValue) ||
309319
parser.parseRSquare())
310320
return parser.emitError(parser.getNameLoc())
311-
<< "expected SSA value or integer";
321+
<< "expected a list of SSA values or integers";
312322
integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
313323
return success();
314324
}
325+
315326
return success();
327+
316328
}
317329

318330
void printOptionalDynamicIndexList(OpAsmPrinter &printer, Operation *op,

mlir/test/Dialect/XeGPU/invalid.mlir

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,70 @@
11
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
22

33
// -----
4-
func.func @create_nd_tdesc_vc_1(%src: memref<24xf32>) {
4+
func.func @create_nd_tdesc_1(%src: memref<24xf32>) {
55
// expected-error@+1 {{Expecting the TensorDesc rank is not greater than the ranks of shape, strides, offsets or the memref source}}
66
%1 = xegpu.create_nd_tdesc %src[0] : memref<24xf32> -> !xegpu.tensor_desc<8x16xf32>
77
return
88
}
99

1010
// -----
1111

12-
func.func @create_nd_tdesc_vc_2(%src: memref<24x32xf32>) {
12+
func.func @create_nd_tdesc_2(%src: memref<24x32xf32>) {
1313
// expected-error@+1 {{TensorDesc should have the same element type with the source if it is a memref}}
1414
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf16>
1515
return
1616
}
1717

1818
// -----
19-
func.func @create_nd_tdesc_vc_3(%src: memref<2x24x32xf32, 3>) {
19+
func.func @create_nd_tdesc_3(%src: memref<2x24x32xf32, 3>) {
2020
// expected-error@+1 {{SLM is only supported for 1D block tensor}}
2121
%1 = xegpu.create_nd_tdesc %src[0, 0, 0] : memref<2x24x32xf32, 3> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = slm>>
2222
return
2323
}
2424

2525
// -----
26-
func.func @create_nd_tdesc_vc_4(%src: memref<2x24x32xf32, 3>) {
26+
func.func @create_nd_tdesc_4(%src: memref<2x24x32xf32, 3>) {
2727
// expected-error@+1 {{Memory space mismatch}}
2828
%1 = xegpu.create_nd_tdesc %src[0, 0, 0] : memref<2x24x32xf32, 3> -> !xegpu.tensor_desc<16xf32>
2929
return
3030
}
3131

3232
// -----
33-
func.func @create_nd_tdesc_subgroup_1(%src: memref<128x128xf32>) {
33+
func.func @create_nd_tdesc_5(%src: memref<128x128xf32>) {
3434
// expected-error@+1 {{cannot distribute [128, 128] using #xegpu.layout<sg_layout = [4, 2], sg_data = [24, 48]>}}
3535
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<128x128xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [24, 48]>>
3636
return
3737
}
3838

3939
// -----
40-
func.func @create_nd_tdesc_subgroup_1(%src: memref<128x128xf32>) {
40+
func.func @create_nd_tdesc_6(%src: memref<128x128xf32>) {
4141
// expected-error@+1 {{cannot distribute [128, 128] using #xegpu.layout<sg_layout = [4, 2], sg_data = [32, 64], inst_data = [24, 48]>}}
4242
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<128x128xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [32, 64], inst_data = [24, 48]>>
4343
return
4444
}
4545

4646
// -----
47-
func.func @create_nd_tdesc_subgroup_1(%src: memref<128x128xf32>) {
47+
func.func @create_nd_tdesc_7(%src: memref<128x128xf32>) {
4848
// expected-error@+1 {{cannot distribute [128, 128] using #xegpu.layout<sg_layout = [4, 2], sg_data = [32, 64], inst_data = [64, 32]>}}
4949
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<128x128xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [32, 64], inst_data = [64, 32]>>
5050
return
5151
}
5252

53+
// -----
54+
func.func @create_nd_tdesc_8(%src: ui64) {
55+
// expected-error@+1 {{'xegpu.create_nd_tdesc' op Expecting strides and shape to be present for integer source}}
56+
%1 = xegpu.create_nd_tdesc %src : ui64-> !xegpu.tensor_desc<128x128xf32>
57+
return
58+
}
59+
60+
// -----
61+
func.func @create_nd_tdesc_9(%src: ui64) {
62+
// expected-error@+1 {{expected mixed offsets rank to match mixed sizes rank}}
63+
%1 = xegpu.create_nd_tdesc %src[0, 0] : ui64-> !xegpu.tensor_desc<128x128xf32>
64+
return
65+
}
66+
67+
5368
// -----
5469
func.func @prefetch_nd_vc_1(%src: memref<24x32xf16>) {
5570
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>

0 commit comments

Comments
 (0)