Skip to content

Commit 9c4e571

Browse files
authored
[mlir][xegpu] Add definitions of MemDescType and related ops. (#153273)
1 parent df0e9f3 commit 9c4e571

File tree

9 files changed

+612
-8
lines changed

9 files changed

+612
-8
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -527,4 +527,34 @@ def XeGPU_RangeAttr : XeGPUAttr<"Range", "range"> {
527527
let genVerifyDecl = 1;
528528
}
529529

530+
def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> {
531+
let summary = [{Specifies memory layouts with named attributes.}];
532+
533+
let description = [{
534+
This attribute stores a collection of named attributes that describe
535+
memory layout properties such as stride, block, etc.
536+
}];
537+
538+
let parameters = (ins "DictionaryAttr": $attrs);
539+
let hasCustomAssemblyFormat = 1;
540+
541+
let extraClassDeclaration = [{
542+
/// Get a specific attribute by name
543+
Attribute getAttr(StringRef name) const {
544+
return getAttrs().get(name);
545+
}
546+
547+
/// Check if a specific attribute exists
548+
bool hasAttr(StringRef name) const {
549+
return getAttrs().contains(name);
550+
}
551+
552+
ArrayAttr getStrides() {
553+
return getAttrs().getAs<ArrayAttr>("stride");
554+
}
555+
556+
}];
557+
558+
}
559+
530560
#endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1097,4 +1097,152 @@ def XeGPU_ConvertLayoutOp: XeGPU_Op<"convert_layout", [Pure, AllTypesMatch<["sou
10971097
let hasCanonicalizer = 1;
10981098
}
10991099

1100+
def isSharedPred : CPred<"isSharedMemory(llvm::cast<mlir::MemRefType>($_self))">;
1101+
class StaticShared1DMemRefOf<list<Type> allowedTypes> :
1102+
ConfinedType<MemRefRankOf<allowedTypes, [1]>, [HasStaticShapePred, isSharedPred],
1103+
"statically shaped " # MemRefOf<allowedTypes>.summary # " for shared memory",
1104+
"mlir::MemRefType">;
1105+
1106+
class SizeInBits<string name> :
1107+
StrFunc<"llvm::cast<mlir::ShapedType>($" # name # ".getType()).getNumElements()"
1108+
"*llvm::cast<mlir::ShapedType>($" # name # ".getType()).getElementTypeBitWidth()">;
1109+
class AllMemSizesMatch<list<string> names> :
1110+
AllMatchSameOperatorTrait<names, SizeInBits<"_self">.result,
1111+
"size in bits">;
1112+
1113+
def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure,
1114+
AllMemSizesMatch<["source", "mem_desc"]>]> {
1115+
let summary = "Create a memory descriptor.";
1116+
let description = [{
1117+
Creates a memory descriptor from a shared local memory (SLM) buffer, and xegpu
1118+
specific memory layout. The resulting memory descriptor has to have the same size
1119+
as the underlying shared local memory.
1120+
1121+
Arguments:
1122+
- `source` : a 1D statically shaped memref with element type i8, representing the raw SLM buffer.
1123+
Results:
1124+
- `mem_desc` : the memory descriptor.
1125+
}];
1126+
let arguments = (ins StaticShared1DMemRefOf<[I8]>:$source);
1127+
let results = (outs XeGPU_MemDesc:$mem_desc);
1128+
let assemblyFormat = "$source prop-dict attr-dict `` `:` type($source) `->` qualified(type($mem_desc))";
1129+
}
1130+
1131+
def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
1132+
AllElementTypesMatch<["mem_desc", "res"]>,
1133+
AllRanksMatch<["mem_desc", "res"]>]> {
1134+
let arguments = (ins XeGPU_MemDesc:$mem_desc,
1135+
Variadic<Index>: $offsets,
1136+
DenseI64ArrayAttr: $const_offsets,
1137+
OptionalAttr<LayoutTrait>:$layout
1138+
);
1139+
let results = (outs XeGPU_ValueType:$res);
1140+
let assemblyFormat = [{
1141+
$mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
1142+
prop-dict attr-dict `` `:` type(operands) `->` type(results)
1143+
}];
1144+
1145+
let description = [{
1146+
This operation loads a 2D block of data from shared local memory (SLM) as specified
1147+
by the provided 2D `mem_desc`. Only 2D memory descriptors are supported; use the
1148+
subview operation to obtain a compatible 2D `mem_desc` from a higher-rank descriptor if needed.
1149+
1150+
Arguments:
1151+
- `mem_desc`: the memory descriptor identifying the SLM region.
1152+
- `offsets`: the coordinates within the matrix to read from.
1153+
- `layout`: [optional] An attribute for guiding distributions among
1154+
subgroups and/or work-items. It currently can accept either
1155+
LayoutAttr or SliceAttr.
1156+
Results:
1157+
- `res`: the matrix elements loaded from SLM.
1158+
}];
1159+
1160+
let builders = [
1161+
OpBuilder<(ins "Type":$res, "TypedValue<MemDescType>": $mem_desc,
1162+
"llvm::ArrayRef<OpFoldResult>": $offsets, "LayoutTrait": $layout)>,
1163+
];
1164+
let extraClassDeclaration = [{
1165+
SmallVector<OpFoldResult> getMixedOffsets() {
1166+
return getMixedValues(getConstOffsets(), getOffsets(), getContext());
1167+
}
1168+
}];
1169+
1170+
let hasVerifier = 1;
1171+
}
1172+
1173+
def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
1174+
AllElementTypesMatch<["mem_desc", "data"]>,
1175+
AllRanksMatch<["mem_desc", "data"]>]> {
1176+
let arguments = (ins
1177+
XeGPU_ValueType:$data,
1178+
XeGPU_MemDesc:$mem_desc,
1179+
Variadic<Index>: $offsets,
1180+
DenseI64ArrayAttr: $const_offsets,
1181+
OptionalAttr<LayoutTrait>:$layout
1182+
);
1183+
let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
1184+
prop-dict attr-dict `` `:` type(operands)}];
1185+
let description = [{
1186+
This operation stores a 2D `data` fragment into the shared local memory region
1187+
specified by a 2D `mem_desc`. Only 2D memory descriptors are supported; use the
1188+
subview operation to obtain a 2D `mem_desc` from a higher-rank descriptor if needed.
1189+
1190+
Arguments:
1191+
- `mem_desc`: the memory descriptor specifying the SLM region.
1192+
- `offsets`: the coordinates within the matrix where the data will be written.
1193+
- `data`: the values to be stored in the matrix.
1194+
- `layout`: [optional] An attribute for guiding distributions among
1195+
subgroups and/or work-items. It currently can accept either
1196+
LayoutAttr or SliceAttr.
1197+
}];
1198+
let builders = [
1199+
OpBuilder<(ins "Value" : $data, "TypedValue<MemDescType>": $mem_desc,
1200+
"llvm::ArrayRef<OpFoldResult>": $offsets, "LayoutTrait": $layout)>,
1201+
];
1202+
let extraClassDeclaration = [{
1203+
SmallVector<OpFoldResult> getMixedOffsets() {
1204+
return getMixedValues(getConstOffsets(), getOffsets(), getContext());
1205+
}
1206+
}];
1207+
1208+
let hasVerifier = 1;
1209+
}
1210+
1211+
def XeGPU_MemDescSubviewOp: XeGPU_Op<"mem_desc_subview",
1212+
[Pure, ViewLikeOpInterface, AllElementTypesMatch<["src", "res"]>]> {
1213+
let description = [{
1214+
Creates a subview of a memory descriptor. The resulting memory descriptor can have
1215+
a lower rank than the source; in this case, the result dimensions correspond to the
1216+
higher-order dimensions of the source memory descriptor.
1217+
1218+
Arguments:
1219+
- `src` : a memory descriptor.
1220+
- `offsets` : the coordinates within the matrix the subview will be created from.
1221+
1222+
Results:
1223+
- `res` : a memory descriptor with smaller size.
1224+
1225+
}];
1226+
let arguments = (ins XeGPU_MemDesc:$src,
1227+
Variadic<Index>:$offsets,
1228+
DenseI64ArrayAttr:$const_offsets);
1229+
let results = (outs XeGPU_MemDesc:$res);
1230+
let assemblyFormat = [{$src `` custom<DynamicIndexList>($offsets, $const_offsets) prop-dict
1231+
attr-dict `` `:` qualified(type($src)) `->` qualified(type($res))}];
1232+
let builders = [
1233+
OpBuilder<(ins "Type": $res, "Value":$src, "llvm::ArrayRef<OpFoldResult>": $offsets)>
1234+
];
1235+
1236+
let extraClassDeclaration = [{
1237+
mlir::Value getViewSource() { return getSrc(); }
1238+
1239+
SmallVector<OpFoldResult> getMixedOffsets() {
1240+
return getMixedValues(getConstOffsets(), getOffsets(), getContext());
1241+
}
1242+
}];
1243+
1244+
let hasVerifier = 1;
1245+
}
1246+
1247+
11001248
#endif // MLIR_DIALECT_XEGPU_IR_XEGPUOPS_TD

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,4 +201,53 @@ def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> {
201201
}];
202202
}
203203

204+
def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "mlir::Type"> {
205+
let summary = "MemDesc describing the data in SLM";
206+
let description = [{
207+
MemDesc represents a block of data stored in shared local memory.
208+
By default, unless a layout attribute is provided, the data is stored
209+
contiguously in row-major order within the region.
210+
211+
Examples:
212+
```mlir
213+
// A multi-dimensional array stored in column-major order.
214+
!xegpu.mem_desc<128x128xf16, #xegpu.mem_layout<stride = [1, 128]>>
215+
216+
// A multi-dimensional array stored in a blocked layout. Elements within the same block
217+
// are stored contiguously in memory. Blocks are stored in row-major order.
218+
!xegpu.mem_desc<128x128xf16, #xegpu.mem_layout<block = [8, 8]>>
219+
220+
// A multi-dimensional array stored in column-major order with blocked layout.
221+
!xegpu.mem_desc<128x128xf16, #xegpu.mem_layout<stride = [1, 128], block = [8, 8]>>
222+
```
223+
}];
224+
let parameters = (ins ArrayRefParameter<"int64_t">: $shape,
225+
"mlir::Type": $elementType,
226+
OptionalParameter<"MemLayoutAttr">: $mem_layout);
227+
228+
let extraClassDeclaration = [{
229+
bool hasRank() const { return true; }
230+
231+
MemDescType cloneWith(std::optional<llvm::ArrayRef<int64_t>> shape, Type elementType) const {
232+
return MemDescType::get(getContext(), shape.value_or(getShape()), elementType, getMemLayout());
233+
}
234+
235+
ArrayAttr getStrides() {
236+
auto layout = getMemLayout();
237+
if (layout && layout.hasAttr("stride")) {
238+
return layout.getStrides();
239+
}
240+
241+
// derive and return default strides
242+
SmallVector<int64_t> defaultStrides;
243+
llvm::append_range(defaultStrides, getShape().drop_front());
244+
llvm::append_values(defaultStrides, 1);
245+
Builder builder(getContext());
246+
return builder.getI64ArrayAttr(defaultStrides);
247+
}
248+
}];
249+
250+
let hasCustomAssemblyFormat = true;
251+
}
252+
204253
#endif // MLIR_DIALECT_XEGPU_IR_XEGPUTYPES_TD

mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ add_mlir_dialect_library(MLIRXeGPUDialect
1717
MLIRAffineUtils
1818
MLIRArithUtils
1919
MLIRDialectUtils
20+
MLIRGPUDialect
21+
MLIRXeVMDialect
2022
MLIRIR
2123
MLIRViewLikeInterface
2224
MLIRVectorDialect

0 commit comments

Comments
 (0)