diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index 1f420c13ebae0..59dcbafebc515 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -527,4 +527,35 @@ def XeGPU_RangeAttr : XeGPUAttr<"Range", "range"> { let genVerifyDecl = 1; } +def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> { + let summary = [{Specifies memory layouts with named attributes.}]; + + let description = [{ + This attribute stores a collection of named attributes that describe + memory layout properties such as stride, block, etc. + }]; + + let parameters = (ins "DictionaryAttr": $attrs); + let hasCustomAssemblyFormat = 1; + + + let extraClassDeclaration = [{ + /// Get a specific attribute by name + Attribute getAttr(StringRef name) const { + return getAttrs().get(name); + } + + /// Check if a specific attribute exists + bool hasAttr(StringRef name) const { + return getAttrs().contains(name); + } + + ArrayAttr getStrides() { + return getAttrs().getAs("stride"); + } + + }]; + +} + #endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index 1a6a34c8d775a..d5e2db0f7551d 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -1101,4 +1101,150 @@ def XeGPU_ConvertLayoutOp: XeGPU_Op<"convert_layout", [Pure, AllTypesMatch<["sou let hasCanonicalizer = 1; } +def isSharedPred : CPred<"isSharedMemory(llvm::cast($_self))">; +class StaticShared1DMemRefOf allowedTypes> : + ConfinedType, [HasStaticShapePred, isSharedPred], + "statically shaped " # MemRefOf.summary # " for shared memory", + "mlir::MemRefType">; + +class SizeInBits : + StrFunc<"llvm::cast($" # name # ".getType()).getNumElements()" + "*llvm::cast($" # name # ".getType()).getElementTypeBitWidth()">; +class AllMemSizesMatch names> : + 
AllMatchSameOperatorTrait.result, + "size in bits">; + +def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure, + AllMemSizesMatch<["source", "mem_desc"]>]> { + let summary = "Create a memory descriptor."; + let description = [{ + Creates a memory descriptor from a shared local memory (SLM) buffer, and xegpu + specific memory layout. The resulting memory descriptor has to have the same size + as the underlying shared local memory. + + Arguments: + - `source` : a 1D statically shaped memref with element type i8, representing the raw SLM buffer. + Results: + - `mem_desc` : the memory descriptor. + }]; + let arguments = (ins StaticShared1DMemRefOf<[I8]>:$source); + let results = (outs XeGPU_MemDesc:$mem_desc); + let assemblyFormat = "$source prop-dict attr-dict `` `:` type($source) `->` qualified(type($mem_desc))"; +} + +def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, + AllElementTypesMatch<["mem_desc", "res"]>, + AllRanksMatch<["mem_desc", "res"]>]> { + let arguments = (ins XeGPU_MemDesc:$mem_desc, + Variadic: $offsets, + DenseI64ArrayAttr: $const_offsets, + OptionalAttr:$layout + ); + let results = (outs XeGPU_ValueType:$res); + let assemblyFormat = [{ + $mem_desc `` custom($offsets, $const_offsets) + prop-dict attr-dict `` `:` type(operands) `->` type(results) + }]; + + let description = [{ + This operation reads a block of data from shared local memory (SLM) + using the provided memory descriptor. + + Arguments: + - `mem_desc`: the memory descriptor identifying the SLM region. + - `offsets`: the coordinates within the matrix to read from. + - `layout`: [optional] An attribute for guiding distributions among + subgroups and/or work-items. It currently can accept either + LayoutAttr or SliceAttr. + Results: + - `res`: the matrix elements loaded from SLM. 
+ }]; + + let builders = [ + OpBuilder<(ins "Type":$res, "TypedValue": $mem_desc, + "llvm::ArrayRef": $offsets, "LayoutTrait": $layout)>, + ]; + let extraClassDeclaration = [{ + SmallVector getMixedOffsets() { + return getMixedValues(getConstOffsets(), getOffsets(), getContext()); + } + }]; + + let hasVerifier = 1; +} + +def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>, + AllElementTypesMatch<["mem_desc", "data"]>, + AllRanksMatch<["mem_desc", "data"]>]> { + let arguments = (ins + XeGPU_ValueType:$data, + XeGPU_MemDesc:$mem_desc, + Variadic: $offsets, + DenseI64ArrayAttr: $const_offsets, + OptionalAttr:$layout + ); + let assemblyFormat = [{ $data `,` $mem_desc `` custom($offsets, $const_offsets) + prop-dict attr-dict `` `:` type(operands)}]; + let description = [{ + This operation writes the `data` fragment into the shared local memory region + identified by `mem_desc`. + + Arguments: + - `mem_desc`: the memory descriptor specifying the SLM region. + - `offsets`: the coordinates within the matrix where the data will be written. + - `data`: the values to be stored in the matrix. + - `layout`: [optional] An attribute for guiding distributions among + subgroups and/or work-items. It currently can accept either + LayoutAttr or SliceAttr. + }]; + let builders = [ + OpBuilder<(ins "Value" : $data, "TypedValue": $mem_desc, + "llvm::ArrayRef": $offsets, "LayoutTrait": $layout)>, + ]; + let extraClassDeclaration = [{ + SmallVector getMixedOffsets() { + return getMixedValues(getConstOffsets(), getOffsets(), getContext()); + } + }]; + + let hasVerifier = 1; +} + +def XeGPU_MemDescSubviewOp: XeGPU_Op<"mem_desc_subview", + [Pure, ViewLikeOpInterface, AllElementTypesMatch<["src", "res"]>]> { + let description = [{ + Creates a subview of a memory descriptor. The resulting memory descriptor can have + a lower rank than the source; in this case, the result dimensions correspond to the + higher-order dimensions of the source memory descriptor. 
+ + Arguments: + - `src` : a memory descriptor. + - `offsets` : the coordinates within the matrix the subview will be created from. + + Results: + - `res` : a memory descriptor with smaller size. + + }]; + let arguments = (ins XeGPU_MemDesc:$src, + Variadic:$offsets, + DenseI64ArrayAttr:$const_offsets); + let results = (outs XeGPU_MemDesc:$res); + let assemblyFormat = [{$src `` custom($offsets, $const_offsets) prop-dict + attr-dict `` `:` qualified(type($src)) `->` qualified(type($res))}]; + let builders = [ + OpBuilder<(ins "Type": $res, "Value":$src, "llvm::ArrayRef": $offsets)> + ]; + + let extraClassDeclaration = [{ + mlir::Value getViewSource() { return getSrc(); } + + SmallVector getMixedOffsets() { + return getMixedValues(getConstOffsets(), getOffsets(), getContext()); + } + }]; + + let hasVerifier = 1; +} + + #endif // MLIR_DIALECT_XEGPU_IR_XEGPUOPS_TD diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td index b268cabb5d266..a4411ec8620da 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td @@ -201,4 +201,53 @@ def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> { }]; } +def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "mlir::Type"> { + let summary = "MemDesc describing the data in SLM"; + let description = [{ + MemDesc represents a block of data stored in shared local memory. + By default, unless a layout attribute is provided, the data is stored + contiguously in row-major order within the region. + + Examples: + ```mlir + // A block of data stored in column-major order. + !xegpu.mem_desc<128x128xf16, #xegpu.mem_layout> + + // A block of data stored in a blocked layout. Elements within the same block + // are stored contiguously in memory. Blocks are stored in row-major order. 
+ !xegpu.mem_desc<128x128xf16, #xegpu.mem_layout> + + // A block of data stored in column-major order with blocked layout. + !xegpu.mem_desc<128x128xf16, #xegpu.mem_layout> + ``` + }]; + let parameters = (ins ArrayRefParameter<"int64_t">: $shape, + "mlir::Type": $elementType, + OptionalParameter<"MemLayoutAttr">: $mem_layout); + + let extraClassDeclaration = [{ + bool hasRank() const { return true; } + + MemDescType cloneWith(std::optional> shape, Type elementType) const { + return MemDescType::get(getContext(), shape.value_or(getShape()), elementType, getMemLayout()); + } + + ArrayAttr getStrides() { + auto layout = getMemLayout(); + if (layout && layout.hasAttr("stride")) { + return layout.getStrides(); + } + + // derive default row-major (contiguous) strides, i.e. suffix products of the shape + SmallVector<int64_t> defaultStrides(getShape().size(), 1); + for (size_t i = defaultStrides.size(); i > 1; --i) + defaultStrides[i - 2] = defaultStrides[i - 1] * getShape()[i - 1]; + Builder builder(getContext()); + return builder.getI64ArrayAttr(defaultStrides); + } + }]; + + let hasCustomAssemblyFormat = true; +} + #endif // MLIR_DIALECT_XEGPU_IR_XEGPUTYPES_TD diff --git a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt index 7c6a4f37db9af..7869a28dfed57 100644 --- a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt @@ -17,6 +17,8 @@ add_mlir_dialect_library(MLIRXeGPUDialect MLIRAffineUtils MLIRArithUtils MLIRDialectUtils + MLIRGPUDialect + MLIRXeVMDialect MLIRIR MLIRViewLikeInterface MLIRVectorDialect diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index d997296a22c20..1b26542ff65a3 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -427,7 +427,7 @@ RangeAttr::verify(llvm::function_ref emitError, // XeGPU_TensorDescType //===----------------------------------------------------------------------===// -mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) { +mlir::Type
TensorDescType::parse(AsmParser &parser) { llvm::SmallVector shape; mlir::Type elementType; mlir::FailureOr encoding; @@ -477,7 +477,7 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) { layout.value_or(mlir::Attribute())); } -void TensorDescType::print(::mlir::AsmPrinter &printer) const { +void TensorDescType::print(AsmPrinter &printer) const { printer << "<"; auto shape = getShape(); @@ -522,10 +522,10 @@ TensorDescType TensorDescType::get(llvm::ArrayRef shape, return Base::get(context, shape, elementType, attr, layout); } -LogicalResult TensorDescType::verify( - llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, - llvm::ArrayRef shape, mlir::Type elementType, - mlir::Attribute encoding, mlir::Attribute layout) { +LogicalResult +TensorDescType::verify(llvm::function_ref emitError, + llvm::ArrayRef shape, mlir::Type elementType, + mlir::Attribute encoding, mlir::Attribute layout) { size_t rank = shape.size(); if (rank == 0) @@ -591,6 +591,119 @@ LogicalResult TensorDescType::verify( return success(); } +//===----------------------------------------------------------------------===// +// XeGPU_MemDescType +//===----------------------------------------------------------------------===// +mlir::Type MemDescType::parse(AsmParser &parser) { + llvm::SmallVector shape; + mlir::Type elementType; + mlir::FailureOr layout; + + // Parse literal '<' + if (parser.parseLess()) + return {}; + + auto shapeLoc = parser.getCurrentLocation(); + if (mlir::failed(parser.parseDimensionList(shape, false, true))) { + parser.emitError(shapeLoc, "failed to parse parameter 'shape'"); + return {}; + } + + auto elemTypeLoc = parser.getCurrentLocation(); + if (mlir::failed(parser.parseType(elementType))) { + parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'"); + return {}; + } + + // parse optional attributes + if (mlir::succeeded(parser.parseOptionalComma())) { + MemLayoutAttr attr; + ParseResult res = parser.parseAttribute(attr); + if 
(mlir::failed(res)) + return {}; + layout = attr; + } + + // Parse literal '>' + if (parser.parseGreater()) + return {}; + + MLIRContext *ctxt = parser.getContext(); + return MemDescType::getChecked( + [&]() { return parser.emitError(parser.getNameLoc()); }, ctxt, shape, + elementType, layout.value_or(MemLayoutAttr())); +} + +void MemDescType::print(AsmPrinter &printer) const { + printer << "<"; + + printer.printDimensionList(getShape()); + printer << 'x'; + printer << getElementType(); + + if (auto layout = getMemLayout()) + printer << ", " << layout; + + printer << ">"; +} + +//===----------------------------------------------------------------------===// +// XeGPU_MemDescType +//===----------------------------------------------------------------------===// + +Attribute MemLayoutAttr::parse(AsmParser &parser, Type type) { + + auto context = parser.getContext(); + llvm::SMLoc loc = parser.getCurrentLocation(); + + llvm::SmallDenseSet seenKeys; + SmallVector attributes; + + auto parseElt = [&]() -> ParseResult { + StringRef nameId; + if (failed(parser.parseKeyword(&nameId))) + return parser.emitError(loc, "expected valid attribute name"); + + if (!seenKeys.insert(nameId).second) + return parser.emitError(loc, "duplicate key '") + << nameId << " in mem layout attribute"; + + if (failed(parser.parseEqual())) + return failure(); + + Attribute attr; + if (failed(parser.parseAttribute(attr))) + return failure(); + attributes.emplace_back(nameId, attr); + return success(); + }; + + // Parse literal '<' + if (parser.parseLess()) + return {}; + + if (failed(parser.parseCommaSeparatedList(parseElt))) + return {}; + + // Parse literal '>' + if (parser.parseGreater()) + return {}; + + return parser.getChecked( + loc, context, DictionaryAttr::get(context, attributes)); +} + +void MemLayoutAttr::print(AsmPrinter &printer) const { + printer << "<"; + ArrayRef attrs = getAttrs().getValue(); + for (size_t i = 0; i < attrs.size(); i++) { + printer << attrs[i].getName().str() << " = 
" << attrs[i].getValue(); + if (i < attrs.size() - 1) + printer << ", "; + } + printer << ">"; +} + } // namespace xegpu } // namespace mlir diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 2cd086feb5deb..1caa37d8353bc 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -7,6 +7,8 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/XeVMDialect.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" @@ -21,6 +23,17 @@ namespace mlir { namespace xegpu { +bool isSharedMemory(const MemRefType &memrefTy) { + Attribute attr = memrefTy.getMemorySpace(); + if (auto intAttr = llvm::dyn_cast(attr)) + return intAttr.getInt() == 3; + if (auto memrefSpace = llvm::dyn_cast(attr)) + return memrefSpace.getValue() == MemorySpace::SLM; + if (auto xevmSpace = llvm::dyn_cast(attr)) + return xevmSpace.getValue() == xevm::AddrSpace::SHARED; + return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr); +} + template static std::string makeString(T array, bool breakline = false) { std::string buf; @@ -925,6 +938,89 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns, patterns.add(context); } +//===----------------------------------------------------------------------===// +// XeGPU_LoadMatrixOp +//===----------------------------------------------------------------------===// +void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res, + TypedValue matrixDesc, + llvm::ArrayRef offsets, + LayoutTrait layout) { + llvm::SmallVector dynamicOffsets; + llvm::SmallVector staticOffsets; + dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); + auto staticOffsetsAttr = 
builder.getDenseI64ArrayAttr(staticOffsets); + build(builder, state, res, matrixDesc, dynamicOffsets, staticOffsetsAttr, + layout); +} + +LogicalResult LoadMatrixOp::verify() { + ArrayRef valueShape = getRes().getType().getShape(); + ArrayRef mdescShape = getMemDesc().getType().getShape(); + if (llvm::any_of(llvm::zip_equal(valueShape, mdescShape), + [](auto p) { return std::get<0>(p) > std::get<1>(p); })) + return emitOpError("result shape must not exceed matrix desc shape."); + return success(); +} + +//===----------------------------------------------------------------------===// +// XeGPU_StoreMatrixOp +//===----------------------------------------------------------------------===// +void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data, + TypedValue matrixDesc, + llvm::ArrayRef offsets, + LayoutTrait layout) { + llvm::SmallVector dynamicOffsets; + llvm::SmallVector staticOffsets; + dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); + auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); + build(builder, state, data, matrixDesc, dynamicOffsets, staticOffsetsAttr, + layout); +} + +LogicalResult StoreMatrixOp::verify() { + ArrayRef dataShape = getData().getType().getShape(); + ArrayRef mdescShape = getMemDesc().getType().getShape(); + if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape), + [](auto p) { return std::get<0>(p) > std::get<1>(p); })) + return emitOpError("data shape must not exceed matrix desc shape."); + + return success(); +} + +//===----------------------------------------------------------------------===// +// XeGPU_MemDescSubviewOp +//===----------------------------------------------------------------------===// + +void MemDescSubviewOp::build(OpBuilder &builder, OperationState &state, + Type resTy, Value src, + llvm::ArrayRef offsets) { + llvm::SmallVector dynamicOffsets; + llvm::SmallVector staticOffsets; + dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); + auto 
staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); + build(builder, state, resTy, src, dynamicOffsets, staticOffsetsAttr); +} + +LogicalResult MemDescSubviewOp::verify() { + MemDescType srcTy = getSrc().getType(); + MemDescType resTy = getRes().getType(); + ArrayRef srcShape = srcTy.getShape(); + ArrayRef resShape = resTy.getShape(); + + if (srcTy.getRank() < resTy.getRank()) + return emitOpError("result rank must not exceed source rank."); + + if (llvm::any_of( + llvm::zip_equal(resShape, srcShape.take_back(resShape.size())), + [](auto p) { return std::get<0>(p) > std::get<1>(p); })) + return emitOpError("result shape must not exceed source shape."); + + if (srcTy.getStrides() != resTy.getStrides()) + return emitOpError("result must inherit the source strides."); + + return success(); +} + } // namespace xegpu } // namespace mlir diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 4a5525c8abb30..5d5d698c88cba 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -475,8 +475,8 @@ struct WgToSgElementwiseOp : public ConversionPattern { // is lowered to: // #a = #xegpu.layout // #b = #xegpu.layout -// store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, matrix_desc<32x64xf32> -// %d = load_matrix %slm <{layout_result_0 = #a}> : matrix_desc<32x64xf32> -> vector<16x32xf32> +// store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, mem_desc<32x64xf32> +// %d = load_matrix %slm <{layout_result_0 = #a}> : mem_desc<32x64xf32> -> vector<16x32xf32> // xegpu.convert_layout %d <{input_layout = #a, target_layout = #b}> : vector<16x32xf32> // clang-format on struct WgToSgConvertLayoutOp diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index 44e15dd7cbb38..e8ef57ca192a9 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ 
b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -762,3 +762,75 @@ func.func @slice_attr_repeat_dim() { return } +// ----- +func.func @create_mem_desc_non_slm() { + %m = memref.alloca() {alignment = 1024} : memref<2048xi8, 1> + // expected-error@+1 {{operand #0 must be statically shaped memref of 8-bit signless integer values for shared memory}} + %mem_desc = xegpu.create_mem_desc %m : memref<2048xi8, 1> -> !xegpu.mem_desc<16x64xf16> + return +} + +// ----- +func.func @create_mem_desc_mismatch_sizes() { + %m = memref.alloca() {alignment = 1024} : memref<2048xi8, 3> + // expected-error@+1 {{failed to verify that all of {source, mem_desc} have same size in bits}} + %mem_desc = xegpu.create_mem_desc %m : memref<2048xi8, 3> -> !xegpu.mem_desc<16x32xf16> + return +} + +// ----- +func.func @load_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16>) { + // expected-error@+1 {{failed to verify that all of {mem_desc, res} have same element type}} + %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> vector<8x16xf32> + return +} + +// ----- +func.func @load_mem_desc_invalid_result_size(%arg0: !xegpu.mem_desc<16x64xf16>) { + // expected-error@+1 {{result shape must not exceed matrix desc shape}} + %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> vector<32x16xf16> + return +} + +// ----- +func.func @store_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf32>) { + // expected-error@+1 {{failed to verify that all of {mem_desc, data} have same element type}} + xegpu.store_matrix %arg1, %arg0[8, 8] : vector<16x16xf32>, !xegpu.mem_desc<16x64xf16> + return +} + +// ----- +func.func @store_mem_desc_invalid_data_size(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<32x32xf16>) { + // expected-error@+1 {{data shape must not exceed matrix desc shape}} + xegpu.store_matrix %arg1, %arg0[8, 8] : vector<32x32xf16>, !xegpu.mem_desc<16x64xf16> + return +} + +// ----- +func.func 
@mem_desc_subview_size_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) { + // expected-error@+1 {{result shape must not exceed source shape}} + %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<32x16xf16> + return +} + +// ----- +func.func @mem_desc_subview_layout_mismatch(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) { + // expected-error@+1 {{result must inherit the source strides}} + %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> !xegpu.mem_desc<8x16xf16> + return +} + +// ----- +func.func @mem_desc_subview_element_type_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) { + // expected-error@+1 {{failed to verify that all of {src, res} have same element type}} + %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf32, #xegpu.mem_layout> + return +} + +// ----- +func.func @mem_desc_subview_rank_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) { + // expected-error@+1 {{result rank must not exceed source rank}} + %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<4x8x16xf16> + return +} + diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir index 67c00f5a9cc2f..35342eca1354c 100644 --- a/mlir/test/Dialect/XeGPU/ops.mlir +++ b/mlir/test/Dialect/XeGPU/ops.mlir @@ -751,4 +751,72 @@ gpu.func @fence() { gpu.return } +// CHECK-LABEL: gpu.func @create_mem_desc({{.*}}) { +gpu.func @create_mem_desc() { + //CHECK: [[alloc:%.+]] = memref.alloca() {alignment = 1024 : i64} : memref<2048xi8, 3> + //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[alloc]] : memref<2048xi8, 3> -> !xegpu.mem_desc<16x64xf16> + %m = memref.alloca() {alignment = 1024} : memref<2048xi8, 3> + %mem_desc = xegpu.create_mem_desc %m : memref<2048xi8, 3> -> !xegpu.mem_desc<16x64xf16> + gpu.return +} + +// CHECK-LABEL: gpu.func @create_mem_desc_with_stride({{.*}}) { +gpu.func @create_mem_desc_with_stride() { + //CHECK: 
[[alloc:%.+]] = memref.alloca() {alignment = 1024 : i64} : memref<2048xi8, 3> + //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[alloc]] : memref<2048xi8, 3> -> !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> + %m = memref.alloca() {alignment = 1024} : memref<2048xi8, 3> + %mem_desc = xegpu.create_mem_desc %m : memref<2048xi8, 3> -> !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> + gpu.return +} + +// CHECK: gpu.func @load_mem_desc([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>) +gpu.func @load_mem_desc(%arg0: !xegpu.mem_desc<16x64xf16>) { + // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16> + %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16> + gpu.return +} + +// CHECK: gpu.func @load_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) +gpu.func @load_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) { + // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> vector<8x16xf16> + %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> vector<8x16xf16> + gpu.return +} + + +// CHECK: gpu.func @store_mem_desc([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>, [[ARG1:%.+]]: vector<16x16xf16>) +gpu.func @store_mem_desc(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf16>) { + // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16> + xegpu.store_matrix %arg1, %arg0[8, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16> + gpu.return +} + +// CHECK: gpu.func @store_mem_desc_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, [[ARG1:%.+]]: vector<16x16xf16>) +gpu.func @store_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>, %arg1: vector<16x16xf16>) { + // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][0, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> + xegpu.store_matrix %arg1, 
%arg0[0, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> + gpu.return +} + +// CHECK: gpu.func @mem_desc_subview([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>) +gpu.func @mem_desc_subview(%arg0: !xegpu.mem_desc<16x64xf16>) { + //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout> + %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout> + gpu.return +} + +// CHECK: gpu.func @mem_desc_subview_lower_rank([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>) +gpu.func @mem_desc_subview_lower_rank(%arg0: !xegpu.mem_desc<16x64xf16>) { + //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<16xf16, #xegpu.mem_layout> + %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<16xf16, #xegpu.mem_layout> + gpu.return +} + +// CHECK: gpu.func @mem_desc_subview_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) +gpu.func @mem_desc_subview_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout>) { + //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout> + %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout> + gpu.return +} + }