Skip to content

Commit 0021a6b

Browse files
authored
[MLIR][XeVM] Add xevm blockload and blockstore op definition. (#158118)
Add op definition for subgroup block load and store ops: xevm.blockload and xevm.blockstore links to related specs: cl_intel_subgroup: https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html#_add_a_new_section_6_13_x_sub_group_read_and_write_functions cl_intel_subgroup_local_block_io: https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_local_block_io.html SPV_INTEL_subgroups: https://github.khronos.org/SPIRV-Registry/extensions/INTEL/SPV_INTEL_subgroups.html
1 parent a3762fb commit 0021a6b

File tree

4 files changed

+144
-0
lines changed

4 files changed

+144
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ class XeVM_Op<string mnemonic, list<Trait> traits = []>
6969
}
7070

7171
def XeVM_ElemType : AnyTypeOf<[AnyI8, AnyI16, AnyI32, F32, TF32, F16, BF16]>;
72+
def XeVM_1DBlockElemType : AnyTypeOf<[I8, I16, I32, I64]>;
7273

7374
//===----------------------------------------------------------------------===//
7475
// XeVM Load Cache Control
@@ -187,6 +188,81 @@ def XeVM_StoreCacheControlAttr
187188
let assemblyFormat = "`<` $value `>`";
188189
}
189190

191+
def XeVM_BlockLoadOp
192+
: XeVM_Op<"blockload">,
193+
Results<(
194+
outs FixedVectorOfRankAndType<[1], [XeVM_1DBlockElemType]>:$res)>,
195+
Arguments<(ins Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr,
196+
OptionalAttr<XeVM_LoadCacheControlAttr>:$cache_control)> {
197+
let summary = "subgroup block load";
198+
let description = [{
199+
Reads one or more components of Result data for each invocation
200+
in the subgroup from the specified `ptr` as a block operation.
201+
The data is read strided, so the first value read is:
202+
```
203+
ptr[ SubgroupLocalInvocationId ]
204+
```
205+
and the second value read is:
206+
```
207+
ptr[ SubgroupLocalInvocationId + SubgroupMaxSize ]
208+
```
209+
Result type may be a scalar or vector type of scalar element type.
210+
211+
The parameters are:
212+
* `ptr` - the base address to load from. Must be uniform across subgroup.
213+
* `cache_control` - an enumerator that sets the cache behaviour
214+
215+
Example:
216+
```mlir
217+
%loaded_a = xevm.blockload %src,
218+
<{cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}>
219+
: (!llvm.ptr<1>) -> vector<4xi16>
220+
```
221+
}];
222+
let assemblyFormat = [{
223+
operands prop-dict attr-dict `:` functional-type(operands, results)
224+
}];
225+
let hasVerifier = 1;
226+
}
227+
228+
def XeVM_BlockStoreOp
229+
: XeVM_Op<"blockstore">,
230+
Arguments<(ins Arg<LLVM_AnyPointer, "", [MemWrite]>:$ptr,
231+
FixedVectorOfRankAndType<[1], [XeVM_1DBlockElemType]>:$val,
232+
OptionalAttr<XeVM_StoreCacheControlAttr>:$cache_control)> {
233+
let summary = "subgroup block store";
234+
let description = [{
235+
Writes one or more components of `val` for each invocation
236+
in the subgroup to the specified `ptr` as a block operation.
237+
The data is written strided, so the first value is written to:
238+
```
239+
ptr[ SubgroupLocalInvocationId ]
240+
```
241+
and the second value is written to:
242+
```
243+
ptr[ SubgroupLocalInvocationId + SubgroupMaxSize ]
244+
```
245+
`val` type may be a scalar or vector type of scalar element type.
246+
247+
The parameters are:
248+
* `ptr` - the base address to store to. Must be uniform across subgroup.
249+
* `val` - the value to store
250+
* `cache_control` - an enumerator that sets the cache behaviour
251+
252+
Example:
253+
```mlir
254+
xevm.blockstore %ptr, %val
255+
<{cache_control=#xevm.store_cache_control<L1uc_L2uc_L3uc>}>
256+
: (!llvm.ptr<1>, vector<4xi16>)
257+
```
258+
}];
259+
260+
let assemblyFormat = [{
261+
operands prop-dict attr-dict `:` `(` type(operands) `)`
262+
}];
263+
let hasVerifier = 1;
264+
}
265+
190266
def XeVM_BlockLoad2dOp
191267
: XeVM_Op<"blockload2d">,
192268
Results<(outs FixedVectorOfRankAndType<[1], [XeVM_ElemType]>:$res)>,

mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
1010
#include "mlir/Dialect/Utils/StaticValueUtils.h"
1111
#include "mlir/IR/DialectImplementation.h"
12+
#include "llvm/ADT/SmallSet.h"
1213
#include "llvm/ADT/TypeSwitch.h"
1314
#include "llvm/Support/FileSystem.h"
1415
#include "llvm/Support/MathExtras.h"
@@ -306,6 +307,36 @@ LogicalResult BlockPrefetch2dOp::verify() {
306307
return success();
307308
}
308309

310+
template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
311+
OpType, BlockLoadOp, BlockStoreOp>::value>>
312+
LogicalResult verify1DBlockArg(OpType op) {
313+
VectorType vTy;
314+
if constexpr (std::is_same_v<OpType, BlockLoadOp>)
315+
vTy = op.getResult().getType();
316+
else
317+
vTy = op.getVal().getType();
318+
int elemTySize = vTy.getElementType().getIntOrFloatBitWidth() / 8;
319+
if (elemTySize == 1) {
320+
llvm::SmallSet<int, 5> validSizes{1, 2, 4, 8, 16};
321+
if (validSizes.contains(vTy.getNumElements()))
322+
return success();
323+
else
324+
return op.emitOpError(
325+
"vector size must be 1, 2, 4, 8 or 16 for 8-bit element type");
326+
} else {
327+
llvm::SmallSet<int, 4> validSizes{1, 2, 4, 8};
328+
if (validSizes.contains(vTy.getNumElements()))
329+
return success();
330+
else
331+
return op.emitOpError(
332+
"vector size must be 1, 2, 4 or 8 for element type > 8 bits");
333+
}
334+
}
335+
336+
LogicalResult BlockLoadOp::verify() { return verify1DBlockArg(*this); }
337+
338+
LogicalResult BlockStoreOp::verify() { return verify1DBlockArg(*this); }
339+
309340
LogicalResult MMAOp::verify() {
310341
if (getC()) {
311342
if (getResult().getType() != getC().getType())

mlir/test/Dialect/LLVMIR/invalid.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1972,6 +1972,20 @@ llvm.func @invalid_xevm_prefetch(%arg0: !llvm.ptr) {
19721972
llvm.return
19731973
}
19741974

1975+
// -----
1976+
llvm.func @invalid_xevm_blockload(%arg0: !llvm.ptr<1>) {
1977+
// expected-error@+1 {{op vector size must be 1, 2, 4 or 8 for element type > 8 bits}}
1978+
%0 = xevm.blockload %arg0 : (!llvm.ptr<1>) -> vector<3xi16>
1979+
llvm.return
1980+
}
1981+
1982+
// -----
1983+
llvm.func @invalid_xevm_blockstore(%arg0: !llvm.ptr<1>, %arg1: vector<5xi8>) {
1984+
// expected-error@+1 {{op vector size must be 1, 2, 4, 8 or 16 for 8-bit element type}}
1985+
xevm.blockstore %arg0, %arg1 : (!llvm.ptr<1>, vector<5xi8>)
1986+
llvm.return
1987+
}
1988+
19751989
// -----
19761990

19771991
llvm.func @invalid_xevm_mma(%loaded_c_casted: vector<4xf32>, %loaded_a: vector<8xi16>, %loaded_b_casted: vector<8xi32>) -> vector<8xf32> {

mlir/test/Dialect/LLVMIR/xevm.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,29 @@ func.func @blockprefetch2d(%ptr: !llvm.ptr<1>, %base_width: i32, %base_height: i
5858
return
5959
}
6060

61+
// -----
62+
// CHECK-LABEL: func.func @blockload(
63+
// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<1>)
64+
func.func @blockload(%ptr: !llvm.ptr<1>) -> vector<4xi16> {
65+
// CHECK: %[[VAR0:.*]] = xevm.blockload %[[ARG0]]
66+
// CHECK-SAME: cache_control = #xevm.load_cache_control<L1uc_L2uc_L3uc>
67+
// CHECK-SAME: (!llvm.ptr<1>) -> vector<4xi16>
68+
%loaded = xevm.blockload %ptr <{cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}>
69+
: (!llvm.ptr<1>) -> vector<4xi16>
70+
return %loaded : vector<4xi16>
71+
}
72+
73+
// -----
74+
// CHECK-LABEL: func.func @blockstore(
75+
// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<1>,
76+
// CHECK-SAME: %[[ARG1:.*]]: vector<4xi32>)
77+
func.func @blockstore(%ptr: !llvm.ptr<1>, %value: vector<4xi32>) {
78+
// CHECK: xevm.blockstore %[[ARG0]], %[[ARG1]]
79+
// CHECK-SAME: (!llvm.ptr<1>, vector<4xi32>)
80+
xevm.blockstore %ptr, %value : (!llvm.ptr<1>, vector<4xi32>)
81+
return
82+
}
83+
6184
// -----
6285
// CHECK-LABEL: func.func @mma(
6386
// CHECK-SAME: %[[ARG0:.*]]: vector<8xf32>, %[[ARG1:.*]]: vector<8xi16>, %[[ARG2:.*]]: vector<8xi32>)

0 commit comments

Comments
 (0)