diff --git a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td index f457f47d56219..514b01a69fb9b 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td @@ -69,6 +69,7 @@ class XeVM_Op traits = []> } def XeVM_ElemType : AnyTypeOf<[AnyI8, AnyI16, AnyI32, F32, TF32, F16, BF16]>; +def XeVM_1DBlockElemType : AnyTypeOf<[I8, I16, I32, I64]>; //===----------------------------------------------------------------------===// // XeVM Load Cache Control @@ -187,6 +188,81 @@ def XeVM_StoreCacheControlAttr let assemblyFormat = "`<` $value `>`"; } +def XeVM_BlockLoadOp + : XeVM_Op<"blockload">, + Results<( + outs FixedVectorOfRankAndType<[1], [XeVM_1DBlockElemType]>:$res)>, + Arguments<(ins Arg:$ptr, + OptionalAttr:$cache_control)> { + let summary = "subgroup block load"; + let description = [{ + Reads one or more components of Result data for each invocation + in the subgroup from the specified `ptr` as a block operation. + The data is read strided, so the first value read is: + ``` + ptr[ SubgroupLocalInvocationId ] + ``` + and the second value read is: + ``` + ptr[ SubgroupLocalInvocationId + SubgroupMaxSize ] + ``` + Result type may be a scalar or vector type of scalar element type. + + The parameters are: + * `ptr` - the base address to load from. Must be uniform across subgroup. + * `cache_control` - an enumerator that sets the cache behaviour + + Example: + ```mlir + %loaded_a = xevm.blockload %src, + <{cache_control=#xevm.load_cache_control}> + : (!llvm.ptr<1>) -> vector<4xi16> + ``` + }]; + let assemblyFormat = [{ + operands prop-dict attr-dict `:` functional-type(operands, results) + }]; + let hasVerifier = 1; +} + +def XeVM_BlockStoreOp + : XeVM_Op<"blockstore">, + Arguments<(ins Arg:$ptr, + FixedVectorOfRankAndType<[1], [XeVM_1DBlockElemType]>:$val, + OptionalAttr:$cache_control)> { + let summary = "subgroup block store"; + let description = [{ + Writes one or more components of `val` for each invocation + in the subgroup to the specified `ptr` as a block operation. + The data is written strided, so the first value is written to: + ``` + ptr[ SubgroupLocalInvocationId ] + ``` + and the second value is written to: + ``` + ptr[ SubgroupLocalInvocationId + SubgroupMaxSize ] + ``` + `val` type may be a scalar or vector type of scalar element type. + + The parameters are: + * `ptr` - the base address to store to. Must be uniform across subgroup. + * `val` - the value to store + * `cache_control` - an enumerator that sets the cache behaviour + + Example: + ```mlir + xevm.blockstore %ptr, %val + <{cache_control=#xevm.store_cache_control}> + : (!llvm.ptr<1>, vector<4xi16>) + ``` + }]; + + let assemblyFormat = [{ + operands prop-dict attr-dict `:` `(` type(operands) `)` + }]; + let hasVerifier = 1; +} + def XeVM_BlockLoad2dOp : XeVM_Op<"blockload2d">, Results<(outs FixedVectorOfRankAndType<[1], [XeVM_ElemType]>:$res)>, diff --git a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp index 24e6a9c284e26..8295492ad73a8 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/GPU/IR/CompilationInterfaces.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/DialectImplementation.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/MathExtras.h" @@ -306,6 +307,36 @@ LogicalResult BlockPrefetch2dOp::verify() { return success(); } +template ::value>> +LogicalResult verify1DBlockArg(OpType op) { + VectorType vTy; + if constexpr (std::is_same_v) + vTy = op.getResult().getType(); + else + vTy = op.getVal().getType(); + int elemTySize = vTy.getElementType().getIntOrFloatBitWidth() / 8; + if (elemTySize == 1) { + llvm::SmallSet validSizes{1, 2, 4, 8, 16}; + if (validSizes.contains(vTy.getNumElements())) + return success(); + else + return op.emitOpError( + "vector size must be 1, 2, 4, 8 or 16 for 8-bit element type"); + } else { + llvm::SmallSet validSizes{1, 2, 4, 8}; + if (validSizes.contains(vTy.getNumElements())) + return success(); + else + return op.emitOpError( + "vector size must be 1, 2, 4 or 8 for element type > 8 bits"); + } +} + +LogicalResult BlockLoadOp::verify() { return verify1DBlockArg(*this); } + +LogicalResult BlockStoreOp::verify() { return verify1DBlockArg(*this); } + LogicalResult MMAOp::verify() { if (getC()) { if (getResult().getType() != getC().getType()) diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index 4394786db5a5d..749fb634dba76 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1972,6 +1972,20 @@ llvm.func @invalid_xevm_prefetch(%arg0: !llvm.ptr) { llvm.return } +// ----- +llvm.func @invalid_xevm_blockload(%arg0: !llvm.ptr<1>) { + // expected-error@+1 {{op vector size must be 1, 2, 4 or 8 for element type > 8 bits}} + %0 = xevm.blockload %arg0 : (!llvm.ptr<1>) -> vector<3xi16> + llvm.return +} + +// ----- +llvm.func @invalid_xevm_blockstore(%arg0: !llvm.ptr<1>, %arg1: vector<5xi8>) { + // expected-error@+1 {{op vector size must be 1, 2, 4, 8 or 16 for 8-bit element type}} + xevm.blockstore %arg0, %arg1 : (!llvm.ptr<1>, vector<5xi8>) + llvm.return +} + // ----- llvm.func @invalid_xevm_mma(%loaded_c_casted: vector<4xf32>, %loaded_a: vector<8xi16>, %loaded_b_casted: vector<8xi32>) -> vector<8xf32> { diff --git a/mlir/test/Dialect/LLVMIR/xevm.mlir b/mlir/test/Dialect/LLVMIR/xevm.mlir index 3dd5f872f898c..bb1f650a1cd12 100644 --- a/mlir/test/Dialect/LLVMIR/xevm.mlir +++ b/mlir/test/Dialect/LLVMIR/xevm.mlir @@ -58,6 +58,29 @@ func.func @blockprefetch2d(%ptr: !llvm.ptr<1>, %base_width: i32, %base_height: i return } +// ----- +// CHECK-LABEL: func.func @blockload( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<1>) +func.func @blockload(%ptr: !llvm.ptr<1>) -> vector<4xi16> { + // CHECK: %[[VAR0:.*]] = xevm.blockload %[[ARG0]] + // CHECK-SAME: cache_control = #xevm.load_cache_control + // CHECK-SAME: (!llvm.ptr<1>) -> vector<4xi16> + %loaded = xevm.blockload %ptr <{cache_control=#xevm.load_cache_control}> + : (!llvm.ptr<1>) -> vector<4xi16> + return %loaded : vector<4xi16> +} + +// ----- +// CHECK-LABEL: func.func @blockstore( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<1>, +// CHECK-SAME: %[[ARG1:.*]]: vector<4xi32>) +func.func @blockstore(%ptr: !llvm.ptr<1>, %value: vector<4xi32>) { + // CHECK: xevm.blockstore %[[ARG0]], %[[ARG1]] + // CHECK-SAME: (!llvm.ptr<1>, vector<4xi32>) + xevm.blockstore %ptr, %value : (!llvm.ptr<1>, vector<4xi32>) + return +} + // ----- // CHECK-LABEL: func.func @mma( // CHECK-SAME: %[[ARG0:.*]]: vector<8xf32>, %[[ARG1:.*]]: vector<8xi16>, %[[ARG2:.*]]: vector<8xi32>)