Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class XeVM_Op<string mnemonic, list<Trait> traits = []>
}

def XeVM_ElemType : AnyTypeOf<[AnyI8, AnyI16, AnyI32, F32, TF32, F16, BF16]>;
// Integer element types accepted by the 1-D subgroup block load/store ops
// (the `blockload` result vector and the `blockstore` value vector).
def XeVM_1DBlockElemType : AnyTypeOf<[I8, I16, I32, I64]>;

//===----------------------------------------------------------------------===//
// XeVM Load Cache Control
Expand Down Expand Up @@ -187,6 +188,81 @@ def XeVM_StoreCacheControlAttr
let assemblyFormat = "`<` $value `>`";
}

// 1-D subgroup block load. The result is a rank-1 fixed-size vector whose
// element count is additionally restricted by the op verifier.
def XeVM_BlockLoadOp
    : XeVM_Op<"blockload">,
      Results<(
          outs FixedVectorOfRankAndType<[1], [XeVM_1DBlockElemType]>:$res)>,
      Arguments<(ins Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr,
          OptionalAttr<XeVM_LoadCacheControlAttr>:$cache_control)> {
  let summary = "subgroup block load";
  let description = [{
    Reads one or more components of Result data for each invocation
    in the subgroup from the specified `ptr` as a block operation.
    The data is read strided, so the first value read is:
    ```
      ptr[ SubgroupLocalInvocationId ]
    ```
    and the second value read is:
    ```
      ptr[ SubgroupLocalInvocationId + SubgroupMaxSize ]
    ```
    Result type is a 1-D fixed-size vector of scalar element type; a single
    value per invocation is expressed as a one-element vector.

    The parameters are:
      * `ptr` - the base address to load from. Must be uniform across the
        subgroup.
      * `cache_control` - an enumerator that sets the cache behaviour

    Example:
    ```mlir
      %loaded_a = xevm.blockload %src
                    <{cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}>
                    : (!llvm.ptr<1>) -> vector<4xi16>
    ```
  }];
  let assemblyFormat = [{
    operands prop-dict attr-dict `:` functional-type(operands, results)
  }];
  let hasVerifier = 1;
}

// 1-D subgroup block store. The stored value is a rank-1 fixed-size vector
// whose element count is additionally restricted by the op verifier.
def XeVM_BlockStoreOp
    : XeVM_Op<"blockstore">,
      Arguments<(ins Arg<LLVM_AnyPointer, "", [MemWrite]>:$ptr,
          FixedVectorOfRankAndType<[1], [XeVM_1DBlockElemType]>:$val,
          OptionalAttr<XeVM_StoreCacheControlAttr>:$cache_control)> {
  let summary = "subgroup block store";
  let description = [{
    Writes one or more components of `val` for each invocation
    in the subgroup to the specified `ptr` as a block operation.
    The data is written strided, so the first value is written to:
    ```
      ptr[ SubgroupLocalInvocationId ]
    ```
    and the second value is written to:
    ```
      ptr[ SubgroupLocalInvocationId + SubgroupMaxSize ]
    ```
    `val` type is a 1-D fixed-size vector of scalar element type; a single
    value per invocation is expressed as a one-element vector.

    The parameters are:
      * `ptr` - the base address to store to
      * `val` - the value to store
      * `cache_control` - an enumerator that sets the cache behaviour

    Example:
    ```mlir
      xevm.blockstore %ptr, %val
        <{cache_control=#xevm.store_cache_control<L1uc_L2uc_L3uc>}>
        : (!llvm.ptr<1>, vector<4xi16>)
    ```
  }];

  let assemblyFormat = [{
    operands prop-dict attr-dict `:` `(` type(operands) `)`
  }];
  let hasVerifier = 1;
}

def XeVM_BlockLoad2dOp
: XeVM_Op<"blockload2d">,
Results<(outs FixedVectorOfRankAndType<[1], [XeVM_ElemType]>:$res)>,
Expand Down
31 changes: 31 additions & 0 deletions mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/MathExtras.h"
Expand Down Expand Up @@ -306,6 +307,36 @@ LogicalResult BlockPrefetch2dOp::verify() {
return success();
}

/// Shared verifier for the 1-D subgroup block load/store ops.
///
/// Checks the element count of the transferred vector (the result of a
/// BlockLoadOp, the `val` operand of a BlockStoreOp):
///   * 8-bit element type:    1, 2, 4, 8 or 16 elements
///   * wider element types:   1, 2, 4 or 8 elements
///
/// These limits come from the OpenCL / SPIR-V Intel extensions rather than
/// from a particular target architecture or chip, so they apply to all Intel
/// hardware.
///
/// \returns success() when the element count is accepted, otherwise the
/// failure produced by emitting an op error on \p op.
template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
                               OpType, BlockLoadOp, BlockStoreOp>::value>>
LogicalResult verify1DBlockArg(OpType op) {
  VectorType vTy;
  if constexpr (std::is_same_v<OpType, BlockLoadOp>)
    vTy = op.getResult().getType();
  else
    vTy = op.getVal().getType();

  const bool isByteElem =
      vTy.getElementType().getIntOrFloatBitWidth() / 8 == 1;
  // Keep the count as int64_t: narrowing it to int (as a SmallSet<int> lookup
  // would) could make an oversized vector alias a valid size.
  const int64_t numElems = vTy.getNumElements();
  const int64_t maxElems = isByteElem ? 16 : 8;

  // The accepted sizes are exactly the powers of two up to `maxElems`.
  if (numElems > 0 && numElems <= maxElems &&
      llvm::isPowerOf2_64(static_cast<uint64_t>(numElems)))
    return success();

  if (isByteElem)
    return op.emitOpError(
        "vector size must be 1, 2, 4, 8 or 16 for 8-bit element type");
  return op.emitOpError(
      "vector size must be 1, 2, 4 or 8 for element type > 8 bits");
}

/// Verifier hook for the blockload op; forwards to the 1-D block check shared
/// with the blockstore op.
LogicalResult BlockLoadOp::verify() {
  return verify1DBlockArg<BlockLoadOp>(*this);
}

/// Verifier hook for the blockstore op; forwards to the 1-D block check shared
/// with the blockload op.
LogicalResult BlockStoreOp::verify() {
  return verify1DBlockArg<BlockStoreOp>(*this);
}

LogicalResult MMAOp::verify() {
if (getC()) {
if (getResult().getType() != getC().getType())
Expand Down
14 changes: 14 additions & 0 deletions mlir/test/Dialect/LLVMIR/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1972,6 +1972,20 @@ llvm.func @invalid_xevm_prefetch(%arg0: !llvm.ptr) {
llvm.return
}

// -----
// A 3-element vector of a 16-bit element type must be rejected by the
// blockload verifier.
llvm.func @invalid_xevm_blockload(%base: !llvm.ptr<1>) {
  // expected-error@+1 {{op vector size must be 1, 2, 4 or 8 for element type > 8 bits}}
  %loaded = xevm.blockload %base : (!llvm.ptr<1>) -> vector<3xi16>
  llvm.return
}

// -----
// A 5-element vector of an 8-bit element type must be rejected by the
// blockstore verifier.
llvm.func @invalid_xevm_blockstore(%base: !llvm.ptr<1>, %data: vector<5xi8>) {
  // expected-error@+1 {{op vector size must be 1, 2, 4, 8 or 16 for 8-bit element type}}
  xevm.blockstore %base, %data : (!llvm.ptr<1>, vector<5xi8>)
  llvm.return
}

// -----

llvm.func @invalid_xevm_mma(%loaded_c_casted: vector<4xf32>, %loaded_a: vector<8xi16>, %loaded_b_casted: vector<8xi32>) -> vector<8xf32> {
Expand Down
23 changes: 23 additions & 0 deletions mlir/test/Dialect/LLVMIR/xevm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,29 @@ func.func @blockprefetch2d(%ptr: !llvm.ptr<1>, %base_width: i32, %base_height: i
return
}

// -----
// CHECK-LABEL: func.func @blockload(
// CHECK-SAME: %[[PTR:.*]]: !llvm.ptr<1>)
func.func @blockload(%base: !llvm.ptr<1>) -> vector<4xi16> {
  // The result is distributed to the work-item lanes: the vector size is the
  // number of elements gathered per lane, not a multiple of the subgroup size.
  // CHECK: %[[RES:.*]] = xevm.blockload %[[PTR]]
  // CHECK-SAME: cache_control = #xevm.load_cache_control<L1uc_L2uc_L3uc>
  // CHECK-SAME: (!llvm.ptr<1>) -> vector<4xi16>
  %vec = xevm.blockload %base <{cache_control = #xevm.load_cache_control<L1uc_L2uc_L3uc>}>
           : (!llvm.ptr<1>) -> vector<4xi16>
  return %vec : vector<4xi16>
}

// -----
// Block store without a cache_control property.
// CHECK-LABEL: func.func @blockstore(
// CHECK-SAME: %[[PTR:.*]]: !llvm.ptr<1>,
// CHECK-SAME: %[[VAL:.*]]: vector<4xi32>)
func.func @blockstore(%base: !llvm.ptr<1>, %data: vector<4xi32>) {
  // CHECK: xevm.blockstore %[[PTR]], %[[VAL]]
  // CHECK-SAME: (!llvm.ptr<1>, vector<4xi32>)
  xevm.blockstore %base, %data : (!llvm.ptr<1>, vector<4xi32>)
  return
}

// -----
// CHECK-LABEL: func.func @mma(
// CHECK-SAME: %[[ARG0:.*]]: vector<8xf32>, %[[ARG1:.*]]: vector<8xi16>, %[[ARG2:.*]]: vector<8xi32>)
Expand Down