Skip to content

Commit 0d758de

Browse files
authored
[MLIR][XeVM] blockload and blockstore ops should use scalar types (#161708)
instead of single element vectors. XeVM type system does not support single element vectors.
1 parent 70c1c8f commit 0d758de

File tree

3 files changed

+19
-12
lines changed

3 files changed

+19
-12
lines changed

mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,9 @@ def XeVM_StoreCacheControlAttr
190190

191191
def XeVM_BlockLoadOp
192192
: XeVM_Op<"blockload">,
193-
Results<(
194-
outs FixedVectorOfRankAndType<[1], [XeVM_1DBlockElemType]>:$res)>,
193+
Results<(outs AnyTypeOf<
194+
[XeVM_1DBlockElemType,
195+
FixedVectorOfRankAndType<[1], [XeVM_1DBlockElemType]>]>:$res)>,
195196
Arguments<(ins Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr,
196197
OptionalAttr<XeVM_LoadCacheControlAttr>:$cache_control)> {
197198
let summary = "subgroup block load";
@@ -228,7 +229,9 @@ def XeVM_BlockLoadOp
228229
def XeVM_BlockStoreOp
229230
: XeVM_Op<"blockstore">,
230231
Arguments<(ins Arg<LLVM_AnyPointer, "", [MemWrite]>:$ptr,
231-
FixedVectorOfRankAndType<[1], [XeVM_1DBlockElemType]>:$val,
232+
AnyTypeOf<[XeVM_1DBlockElemType,
233+
FixedVectorOfRankAndType<[1],
234+
[XeVM_1DBlockElemType]>]>:$val,
232235
OptionalAttr<XeVM_StoreCacheControlAttr>:$cache_control)> {
233236
let summary = "subgroup block store";
234237
let description = [{

mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -310,26 +310,30 @@ LogicalResult BlockPrefetch2dOp::verify() {
310310
template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
311311
OpType, BlockLoadOp, BlockStoreOp>::value>>
312312
LogicalResult verify1DBlockArg(OpType op) {
313-
VectorType vTy;
313+
Type srcOrDstTy;
314314
if constexpr (std::is_same_v<OpType, BlockLoadOp>)
315-
vTy = op.getResult().getType();
315+
srcOrDstTy = op.getResult().getType();
316316
else
317-
vTy = op.getVal().getType();
317+
srcOrDstTy = op.getVal().getType();
318+
VectorType vTy = dyn_cast<VectorType>(srcOrDstTy);
319+
// scalar case is always valid
320+
if (!vTy)
321+
return success();
318322
int elemTySize = vTy.getElementType().getIntOrFloatBitWidth() / 8;
319323
if (elemTySize == 1) {
320-
llvm::SmallSet<int, 5> validSizes{1, 2, 4, 8, 16};
324+
llvm::SmallSet<int, 4> validSizes{2, 4, 8, 16};
321325
if (validSizes.contains(vTy.getNumElements()))
322326
return success();
323327
else
324328
return op.emitOpError(
325-
"vector size must be 1, 2, 4, 8 or 16 for 8-bit element type");
329+
"vector size must be 2, 4, 8 or 16 for 8-bit element type");
326330
} else {
327-
llvm::SmallSet<int, 4> validSizes{1, 2, 4, 8};
331+
llvm::SmallSet<int, 3> validSizes{2, 4, 8};
328332
if (validSizes.contains(vTy.getNumElements()))
329333
return success();
330334
else
331335
return op.emitOpError(
332-
"vector size must be 1, 2, 4 or 8 for element type > 8 bits");
336+
"vector size must be 2, 4 or 8 for element type > 8 bits");
333337
}
334338
}
335339

mlir/test/Dialect/LLVMIR/invalid.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1973,14 +1973,14 @@ llvm.func @invalid_xevm_prefetch(%arg0: !llvm.ptr) {
19731973

19741974
// -----
19751975
llvm.func @invalid_xevm_blockload(%arg0: !llvm.ptr<1>) {
1976-
// expected-error@+1 {{op vector size must be 1, 2, 4 or 8 for element type > 8 bits}}
1976+
// expected-error@+1 {{op vector size must be 2, 4 or 8 for element type > 8 bits}}
19771977
%0 = xevm.blockload %arg0 : (!llvm.ptr<1>) -> vector<3xi16>
19781978
llvm.return
19791979
}
19801980

19811981
// -----
19821982
llvm.func @invalid_xevm_blockstore(%arg0: !llvm.ptr<1>, %arg1: vector<5xi8>) {
1983-
// expected-error@+1 {{op vector size must be 1, 2, 4, 8 or 16 for 8-bit element type}}
1983+
// expected-error@+1 {{op vector size must be 2, 4, 8 or 16 for 8-bit element type}}
19841984
xevm.blockstore %arg0, %arg1 : (!llvm.ptr<1>, vector<5xi8>)
19851985
llvm.return
19861986
}

0 commit comments

Comments
 (0)