Skip to content

[mlir][xegpu] Add definitions of MatrixDescType and related ops. #153273

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
137 changes: 137 additions & 0 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1101,4 +1101,141 @@ def XeGPU_ConvertLayoutOp: XeGPU_Op<"convert_layout", [Pure, AllTypesMatch<["sou
let hasCanonicalizer = 1;
}

// Predicate on a memref type: casts the constrained type to mlir::MemRefType
// and asks the C++ helper `isSharedMemory` whether it is in shared memory.
def isSharedPred : CPred<"isSharedMemory(llvm::cast<mlir::MemRefType>($_self))">;
// Type constraint for a rank-1 memref over `allowedTypes` that additionally
// satisfies HasStaticShapePred and isSharedPred. The C++ class of the
// constrained type is mlir::MemRefType.
class StaticShared1DMemRefOf<list<Type> allowedTypes> :
ConfinedType<MemRefRankOf<allowedTypes, [1]>, [HasStaticShapePred, isSharedPred],
"statically shaped " # MemRefOf<allowedTypes>.summary # " for shared memory",
"mlir::MemRefType">;

// Builds a C++ expression string that evaluates to the total size in bits of
// the value called `name`: the number of elements of its ShapedType multiplied
// by the bit width of its element type.
class SizeInBits<string name> :
StrFunc<"llvm::cast<mlir::ShapedType>($" # name # ".getType()).getNumElements()"
"*llvm::cast<mlir::ShapedType>($" # name # ".getType()).getElementTypeBitWidth()">;
// Op trait requiring that every operand/result listed in `names` yields the
// same SizeInBits value; "size in bits" is the summary used for the trait.
class AllMemSizesMatch<list<string> names> :
AllMatchSameOperatorTrait<names, SizeInBits<"_self">.result,
"size in bits">;

def XeGPU_CreateMatrixDescOp: XeGPU_Op<"create_matrix_desc", [Pure,
AllMemSizesMatch<["source", "matrix_desc"]>]> {
let summary = "Create a matrix descriptor.";
let description = [{
Creates a matrix descriptor from a shared local memory (SLM) buffer.
The resulting matrix descriptor has to have the same size as the underlying
shared local memory.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shared local memory => memory. The memory descriptor itself doesn't have to be associated with shared local memory.


Arguments:
- `source` : a 1D statically shaped memref with element type i8, representing the raw SLM buffer.
Results:
- `matrix_desc` : the matrix descriptor.
}];
let arguments = (ins StaticShared1DMemRefOf<[I8]>:$source);
let results = (outs XeGPU_MatrixDesc:$matrix_desc);
let assemblyFormat = "$source prop-dict attr-dict `` `:` type($source) `->` qualified(type($matrix_desc))";
}

// xegpu.load_matrix: reads a block of data through a matrix descriptor.
// The op is declared with a MemRead effect; the traits require the descriptor
// and the result to agree on element type and rank. Shapes are not required to
// be equal; the shape relation is checked by the custom verifier instead.
def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
AllElementTypesMatch<["matrix_desc", "res"]>,
AllRanksMatch<["matrix_desc", "res"]>]> {
// Offsets are split into dynamic SSA values (`offsets`) and a static
// DenseI64ArrayAttr (`const_offsets`); the `layout` attribute is optional.
let arguments = (ins XeGPU_MatrixDesc:$matrix_desc,
Variadic<Index>: $offsets,
DenseI64ArrayAttr: $const_offsets,
OptionalAttr<LayoutTrait>:$layout
);
let results = (outs XeGPU_ValueType:$res);
// The offset list is printed/parsed by the custom DynamicIndexList directive
// from the (`offsets`, `const_offsets`) pair.
let assemblyFormat = [{
$matrix_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
prop-dict attr-dict `` `:` type(operands) `->` type(results)
}];

let description = [{
This operation reads a block of data from shared local memory (SLM)
using the provided matrix descriptor.

Arguments:
- `matrix_desc`: the matrix descriptor identifying the SLM region.
- `offsets`: the coordinates within the matrix to read from.
Results:
- `res`: the matrix elements loaded from SLM.
}];

// Convenience builder taking the offsets as a single list of OpFoldResult
// instead of separate dynamic and static lists.
let builders = [
OpBuilder<(ins "Type":$res, "TypedValue<MatrixDescType>": $matrix_desc,
"llvm::ArrayRef<OpFoldResult>": $offsets, "LayoutTrait": $layout)>,
];
// getMixedOffsets() recombines `const_offsets` and `offsets` into one
// OpFoldResult list via getMixedValues.
let extraClassDeclaration = [{
SmallVector<OpFoldResult> getMixedOffsets() {
return getMixedValues(getConstOffsets(), getOffsets(), getContext());
}
}];

let hasVerifier = 1;
}

// xegpu.store_matrix: writes `data` through a matrix descriptor.
// The op is declared with a MemWrite effect; the traits require the descriptor
// and the data operand to agree on element type and rank. The shape relation
// is checked by the custom verifier.
def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
AllElementTypesMatch<["matrix_desc", "data"]>,
AllRanksMatch<["matrix_desc", "data"]>]> {
// Offsets are split into dynamic SSA values (`offsets`) and a static
// DenseI64ArrayAttr (`const_offsets`); the `layout` attribute is optional.
let arguments = (ins
XeGPU_MatrixDesc:$matrix_desc,
Variadic<Index>: $offsets,
DenseI64ArrayAttr: $const_offsets,
XeGPU_ValueType:$data,
OptionalAttr<LayoutTrait>:$layout
);
// The offset list is printed/parsed by the custom DynamicIndexList directive
// from the (`offsets`, `const_offsets`) pair; the op has no results.
let assemblyFormat = [{
$matrix_desc `` custom<DynamicIndexList>($offsets, $const_offsets) `,` $data
prop-dict attr-dict `:` type(operands)
}];
let description = [{
This operation writes the `data` fragment into the shared local memory region
identified by `matrix_desc`.

Arguments:
- `matrix_desc`: the matrix descriptor specifying the SLM region.
- `offsets`: the coordinates within the matrix where the data will be written.
- `data`: the values to be stored in the matrix.
}];
// Convenience builder taking the offsets as a single list of OpFoldResult
// instead of separate dynamic and static lists.
let builders = [
OpBuilder<(ins "TypedValue<MatrixDescType>": $matrix_desc, "llvm::ArrayRef<OpFoldResult>": $offsets,
"Value" : $data, "LayoutTrait": $layout)>,
];
// getMixedOffsets() recombines `const_offsets` and `offsets` into one
// OpFoldResult list via getMixedValues.
let extraClassDeclaration = [{
SmallVector<OpFoldResult> getMixedOffsets() {
return getMixedValues(getConstOffsets(), getOffsets(), getContext());
}
}];

let hasVerifier = 1;
}

// xegpu.matrix_desc_subview: derives a matrix descriptor that views a region
// of another matrix descriptor. The traits require source and result to agree
// on element type and rank; the custom verifier additionally checks the shape
// bound and that the layout is unchanged.
def XeGPU_MatrixDescSubviewOp: XeGPU_Op<"matrix_desc_subview", [Pure, ViewLikeOpInterface,
    AllElementTypesMatch<["src", "res"]>,
    AllRanksMatch<["src", "res"]>]> {
  let description = [{
    Create a subview of a matrix descriptor.

    Arguments:
    - `src` : a matrix descriptor.
    - `offsets` : the coordinates within the matrix the subview will be created from.
    Results:
    - `res` : the matrix descriptor of the subview. Its shape must not exceed
      the shape of `src` in any dimension, and it must carry the same layout
      as `src`.
  }];
  // Offsets are split into dynamic SSA values (`offsets`) and a static
  // DenseI64ArrayAttr (`const_offsets`); the `layout` attribute is optional.
  let arguments = (ins XeGPU_MatrixDesc:$src,
                       Variadic<Index>:$offsets,
                       DenseI64ArrayAttr:$const_offsets,
                       OptionalAttr<LayoutTrait>: $layout);
  let results = (outs XeGPU_MatrixDesc:$res);
  let assemblyFormat = [{$src `` custom<DynamicIndexList>($offsets, $const_offsets) prop-dict
                         attr-dict `` `:` qualified(type($src)) `->` qualified(type($res))}];
  // Convenience builder taking the offsets as a single list of OpFoldResult
  // instead of separate dynamic and static lists.
  let builders = [
    OpBuilder<(ins "Type": $res, "Value":$src, "llvm::ArrayRef<OpFoldResult>": $offsets, "LayoutTrait": $layout)>
  ];

  // getViewSource() implements ViewLikeOpInterface by returning `src`;
  // getMixedOffsets() recombines `const_offsets` and `offsets` into one
  // OpFoldResult list via getMixedValues.
  let extraClassDeclaration = [{
    mlir::Value getViewSource() { return getSrc(); }

    SmallVector<OpFoldResult> getMixedOffsets() {
      return getMixedValues(getConstOffsets(), getOffsets(), getContext());
    }
  }];

  let hasVerifier = 1;
}


#endif // MLIR_DIALECT_XEGPU_IR_XEGPUOPS_TD
22 changes: 22 additions & 0 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -201,4 +201,26 @@ def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> {
}];
}

def XeGPU_MatrixDesc: XeGPUTypeDef<"MatrixDesc", "matrix_desc", [ShapedTypeInterface], "mlir::Type"> {
let summary = "MatrixDesc describing the data in SLM";
let description = [{
MatrixDesc represents a block of data stored in shared local memory.
By default, unless a layout attribute is provided, the data is stored
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this layout?
I assume this is not the distribution layout, since you said that is not part of the matrix desc type.

Copy link
Contributor

@silee2 silee2 Aug 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the tests, there is usage like
!xegpu.matrix_desc<16x64xf16, strided<[1, 16]>>
Is strided<[1,16]> the layout attribute?
Could that attribute not be represented by the
MemRefLayoutAttrInterface used in MemRefType?
https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/IR/BuiltinTypes.td#L796-#L801
Looks very similar to MemRefType. For example,
memref<12x4xf32, strided<[4, 1], offset: 5>>

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please refer to intel/mlir-extensions#1092 for the motivation and an explanation of the SLM memory layout of the matrix descriptor.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 to the questions here. It just reads as a MemRef layout, which is fine, but it could use explicit clarification, as the term "layout" is becoming quite overloaded within the dialect.

Could you add more description here? Or at least an example snippet.

contiguously in row-major order within the region.
}];
let parameters = (ins ArrayRefParameter<"int64_t">: $shape,
"mlir::Type": $elementType,
OptionalParameter<"mlir::Attribute">: $layout);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider using mem_layout instead of layout, to differentiate it from XeGPU.layout, which describes the mapping from sg/lane ids to the data.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed


let extraClassDeclaration = [{
bool hasRank() const { return true; }

MatrixDescType cloneWith(std::optional<llvm::ArrayRef<int64_t>> shape, Type elementType) const {
return MatrixDescType::get(getContext(), shape.value_or(getShape()), elementType, getLayout());
}
}];

let hasCustomAssemblyFormat = true;
}

#endif // MLIR_DIALECT_XEGPU_IR_XEGPUTYPES_TD
1 change: 1 addition & 0 deletions mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ add_mlir_dialect_library(MLIRXeGPUDialect
MLIRAffineUtils
MLIRArithUtils
MLIRDialectUtils
MLIRGPUDialect
MLIRIR
MLIRViewLikeInterface
MLIRVectorDialect
Expand Down
56 changes: 56 additions & 0 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,62 @@ LogicalResult TensorDescType::verify(
return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_MatrixDescType
//===----------------------------------------------------------------------===//
/// Parses the custom assembly form of MatrixDescType:
///   `<` dim-list `x` element-type (`,` attribute)? `>`
/// Returns a null Type if any piece fails to parse or if getChecked rejects
/// the parsed parameters.
mlir::Type MatrixDescType::parse(::mlir::AsmParser &parser) {
  // Opening '<'.
  if (mlir::failed(parser.parseLess()))
    return {};

  // Dimension list: dynamic sizes are rejected and each dimension, including
  // the last one, is followed by an 'x' separator.
  llvm::SmallVector<int64_t> dims;
  auto dimsLoc = parser.getCurrentLocation();
  if (mlir::failed(parser.parseDimensionList(dims, /*allowDynamic=*/false,
                                             /*withTrailingX=*/true))) {
    parser.emitError(dimsLoc, "failed to parse parameter 'shape'");
    return {};
  }

  // Element type.
  mlir::Type elemTy;
  auto elemTyLoc = parser.getCurrentLocation();
  if (mlir::failed(parser.parseType(elemTy))) {
    parser.emitError(elemTyLoc, "failed to parse parameter 'elementType'");
    return {};
  }

  // Optional trailing `, <attribute>`; the attribute stays null when absent.
  mlir::Attribute layoutAttr;
  if (mlir::succeeded(parser.parseOptionalComma()) &&
      mlir::failed(parser.parseAttribute(layoutAttr)))
    return {};

  // Closing '>'.
  if (mlir::failed(parser.parseGreater()))
    return {};

  // Construct through getChecked so that invalid parameters are reported at
  // the location of the type name.
  auto emitAtName = [&]() { return parser.emitError(parser.getNameLoc()); };
  return MatrixDescType::getChecked(emitAtName, parser.getContext(), dims,
                                    elemTy, layoutAttr);
}

/// Prints the type as `<` dims `x` element-type (`, ` layout)? `>`; the layout
/// is emitted only when one is set.
void MatrixDescType::print(::mlir::AsmPrinter &printer) const {
  printer << "<";
  printer.printDimensionList(getShape());
  printer << "x" << getElementType();

  // A null layout attribute means the optional parameter is absent.
  if (auto memLayout = getLayout())
    printer << ", " << memLayout;

  printer << ">";
}

} // namespace xegpu
} // namespace mlir

Expand Down
87 changes: 87 additions & 0 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
Expand All @@ -21,6 +22,15 @@
namespace mlir {
namespace xegpu {

bool isSharedMemory(const MemRefType &memrefTy) {
Attribute attr = memrefTy.getMemorySpace();
if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
return intAttr.getInt() == 3;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

static_cast<int>(xevm::AddrSpace::SHARED) seems more appropriate.
https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td#L329

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added support for this.

if (auto memrefSpace = llvm::dyn_cast<MemorySpaceAttr>(attr))
return memrefSpace.getValue() == MemorySpace::SLM;
return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
}

template <typename T>
static std::string makeString(T array, bool breakline = false) {
std::string buf;
Expand Down Expand Up @@ -925,6 +935,83 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<FoldConvertLayoutOp>(context);
}

//===----------------------------------------------------------------------===//
// XeGPU_LoadMatrixOp
//===----------------------------------------------------------------------===//
/// Builder taking the offsets as one mixed list. The list is split into its
/// dynamic (Value) and static (int64_t) parts and forwarded to the builder
/// that takes them separately.
void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
                         TypedValue<MatrixDescType> matrixDesc,
                         llvm::ArrayRef<OpFoldResult> offsets,
                         LayoutTrait layout) {
  llvm::SmallVector<int64_t> constOffs;
  llvm::SmallVector<Value> dynOffs;
  dispatchIndexOpFoldResults(offsets, dynOffs, constOffs);
  build(builder, state, res, matrixDesc, dynOffs,
        builder.getDenseI64ArrayAttr(constOffs), layout);
}

/// Verifies that the loaded value fits inside the matrix descriptor: no
/// dimension of the result may be larger than the matching descriptor
/// dimension.
LogicalResult LoadMatrixOp::verify() {
  ArrayRef<int64_t> resShape = getRes().getType().getShape();
  ArrayRef<int64_t> descShape = getMatrixDesc().getType().getShape();

  // An AllShapesMatch trait is deliberately not used here: a load may read a
  // block that is smaller than the descriptor, so only an upper bound per
  // dimension is enforced.
  for (auto [resDim, descDim] : llvm::zip_equal(resShape, descShape))
    if (resDim > descDim)
      return emitOpError("result shape must not exceed matrix desc shape.");

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_StoreMatrixOp
//===----------------------------------------------------------------------===//
/// Builder taking the offsets as one mixed list. The list is split into its
/// dynamic (Value) and static (int64_t) parts and forwarded to the builder
/// that takes them separately.
void StoreMatrixOp::build(OpBuilder &builder, OperationState &state,
                          TypedValue<MatrixDescType> matrixDesc,
                          llvm::ArrayRef<OpFoldResult> offsets, Value data,
                          LayoutTrait layout) {
  llvm::SmallVector<int64_t> constOffs;
  llvm::SmallVector<Value> dynOffs;
  dispatchIndexOpFoldResults(offsets, dynOffs, constOffs);
  build(builder, state, matrixDesc, dynOffs,
        builder.getDenseI64ArrayAttr(constOffs), data, layout);
}

/// Verifies that the stored data fits inside the matrix descriptor: no
/// dimension of `data` may be larger than the matching descriptor dimension.
LogicalResult StoreMatrixOp::verify() {
  ArrayRef<int64_t> srcShape = getData().getType().getShape();
  ArrayRef<int64_t> descShape = getMatrixDesc().getType().getShape();

  // Dimension-wise upper-bound check.
  for (auto [dataDim, descDim] : llvm::zip_equal(srcShape, descShape))
    if (dataDim > descDim)
      return emitOpError("data shape must not exceed matrix desc shape.");

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_MatrixDescSubviewOp
//===----------------------------------------------------------------------===//

/// Builder taking the offsets as one mixed list. The list is split into its
/// dynamic (Value) and static (int64_t) parts and forwarded to the builder
/// that takes them separately.
void MatrixDescSubviewOp::build(OpBuilder &builder, OperationState &state,
                                Type resTy, Value src,
                                llvm::ArrayRef<OpFoldResult> offsets,
                                LayoutTrait layout) {
  llvm::SmallVector<int64_t> constOffs;
  llvm::SmallVector<Value> dynOffs;
  dispatchIndexOpFoldResults(offsets, dynOffs, constOffs);
  build(builder, state, resTy, src, dynOffs,
        builder.getDenseI64ArrayAttr(constOffs), layout);
}

/// Verifies a subview: the result may not be larger than the source in any
/// dimension, and the result type must carry the same layout as the source.
LogicalResult MatrixDescSubviewOp::verify() {
  auto srcTy = getSrc().getType();
  auto resTy = getRes().getType();

  // Dimension-wise upper-bound check of the result against the source.
  for (auto [resDim, srcDim] :
       llvm::zip_equal(resTy.getShape(), srcTy.getShape()))
    if (resDim > srcDim)
      return emitOpError("result shape must not exceed source shape.");

  // A subview only narrows the region; it must not change the layout.
  if (srcTy.getLayout() != resTy.getLayout())
    return emitOpError("result must inherit the source layout.");

  return success();
}

} // namespace xegpu
} // namespace mlir

Expand Down
Loading