Skip to content

Commit 56d8f73

Browse files
committed
Add element type and vector size restrictions.
1 parent aa5b677 commit 56d8f73

File tree

3 files changed

+51
-2
lines changed

3 files changed

+51
-2
lines changed

mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ class XeVM_Op<string mnemonic, list<Trait> traits = []>
6969
}
7070

7171
def XeVM_ElemType : AnyTypeOf<[AnyI8, AnyI16, AnyI32, F32, TF32, F16, BF16]>;
72+
def XeVM_1DBlockElemType : AnyTypeOf<[I8, I16, I32, I64]>;
7273

7374
//===----------------------------------------------------------------------===//
7475
// XeVM Load Cache Control
@@ -189,7 +190,8 @@ def XeVM_StoreCacheControlAttr
189190

190191
def XeVM_BlockLoadOp
191192
: XeVM_Op<"blockload">,
192-
Results<(outs FixedVectorOfRankAndType<[1], [XeVM_ElemType]>:$res)>,
193+
Results<(
194+
outs FixedVectorOfRankAndType<[1], [XeVM_1DBlockElemType]>:$res)>,
193195
Arguments<(ins Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr,
194196
OptionalAttr<XeVM_LoadCacheControlAttr>:$cache_control)> {
195197
let summary = "subgroup block load";
@@ -220,12 +222,13 @@ def XeVM_BlockLoadOp
220222
let assemblyFormat = [{
221223
operands prop-dict attr-dict `:` functional-type(operands, results)
222224
}];
225+
let hasVerifier = 1;
223226
}
224227

225228
def XeVM_BlockStoreOp
226229
: XeVM_Op<"blockstore">,
227230
Arguments<(ins Arg<LLVM_AnyPointer, "", [MemWrite]>:$ptr,
228-
FixedVectorOfRankAndType<[1], [XeVM_ElemType]>:$val,
231+
FixedVectorOfRankAndType<[1], [XeVM_1DBlockElemType]>:$val,
229232
OptionalAttr<XeVM_StoreCacheControlAttr>:$cache_control)> {
230233
let summary = "subgroup block store";
231234
let description = [{
@@ -257,6 +260,7 @@ def XeVM_BlockStoreOp
257260
let assemblyFormat = [{
258261
operands prop-dict attr-dict `:` `(` type(operands) `)`
259262
}];
263+
let hasVerifier = 1;
260264
}
261265

262266
def XeVM_BlockLoad2dOp

mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
1010
#include "mlir/Dialect/Utils/StaticValueUtils.h"
1111
#include "mlir/IR/DialectImplementation.h"
12+
#include "llvm/ADT/SmallSet.h"
1213
#include "llvm/ADT/TypeSwitch.h"
1314
#include "llvm/Support/FileSystem.h"
1415
#include "llvm/Support/MathExtras.h"
@@ -306,6 +307,36 @@ LogicalResult BlockPrefetch2dOp::verify() {
306307
return success();
307308
}
308309

310+
template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
311+
OpType, BlockLoadOp, BlockStoreOp>::value>>
312+
LogicalResult verify1DBlockArg(OpType op) {
313+
VectorType vTy;
314+
if constexpr (std::is_same_v<OpType, BlockLoadOp>)
315+
vTy = op.getResult().getType();
316+
else
317+
vTy = op.getVal().getType();
318+
int elemTySize = vTy.getElementType().getIntOrFloatBitWidth() / 8;
319+
if (elemTySize == 1) {
320+
llvm::SmallSet<int, 5> validSizes{1, 2, 4, 8, 16};
321+
if (validSizes.contains(vTy.getNumElements()))
322+
return success();
323+
else
324+
return op.emitOpError(
325+
"vector size must be 1, 2, 4, 8 or 16 for 8-bit element type");
326+
} else {
327+
llvm::SmallSet<int, 4> validSizes{1, 2, 4, 8};
328+
if (validSizes.contains(vTy.getNumElements()))
329+
return success();
330+
else
331+
return op.emitOpError(
332+
"vector size must be 1, 2, 4 or 8 for element type > 8 bits");
333+
}
334+
}
335+
336+
LogicalResult BlockLoadOp::verify() { return verify1DBlockArg(*this); }
337+
338+
LogicalResult BlockStoreOp::verify() { return verify1DBlockArg(*this); }
339+
309340
LogicalResult MMAOp::verify() {
310341
if (getC()) {
311342
if (getResult().getType() != getC().getType())

mlir/test/Dialect/LLVMIR/invalid.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1972,6 +1972,20 @@ llvm.func @invalid_xevm_prefetch(%arg0: !llvm.ptr) {
19721972
llvm.return
19731973
}
19741974

1975+
// -----
1976+
llvm.func @invalid_xevm_blockload(%arg0: !llvm.ptr<1>) {
1977+
// expected-error@+1 {{op vector size must be 1, 2, 4 or 8 for element type > 8 bits}}
1978+
%0 = xevm.blockload %arg0 : (!llvm.ptr<1>) -> vector<3xi16>
1979+
llvm.return
1980+
}
1981+
1982+
// -----
1983+
llvm.func @invalid_xevm_blockstore(%arg0: !llvm.ptr<1>, %arg1: vector<5xi8>) {
1984+
// expected-error@+1 {{op vector size must be 1, 2, 4, 8 or 16 for 8-bit element type}}
1985+
xevm.blockstore %arg0, %arg1 : (!llvm.ptr<1>, vector<5xi8>)
1986+
llvm.return
1987+
}
1988+
19751989
// -----
19761990

19771991
llvm.func @invalid_xevm_mma(%loaded_c_casted: vector<4xf32>, %loaded_a: vector<8xi16>, %loaded_b_casted: vector<8xi32>) -> vector<8xf32> {

0 commit comments

Comments
 (0)