Skip to content

Commit 6cae29f

Browse files
authored
[MLIR][XeGPU] XeVM lowering support for load_matrix/store_matrix (#162780)
This PR adds lowering of xegpu.load_matrix/store_matrix to either xevm.blockload/blockstore or llvm.load/store, depending on work-item (wi) level attributes. It includes the following components: 1. Adds a wi-level attribute: subgroup_block_io. 2. Expands the load_matrix/store_matrix op definitions to support scalar data (in addition to vector data). 3. Adds a member function to mem_desc to compute the linearized address for n-D offsets. 4. Adds lowering that depends on wi-level attributes: a) if the subgroup_block_io attribute is present, lower to xevm.blockload/blockstore; b) otherwise, lower to llvm.load/store. If the result is a vector, lower to llvm.load/store with a vector operand.
1 parent fd08af0 commit 6cae29f

File tree

13 files changed

+715
-163
lines changed

13 files changed

+715
-163
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -712,10 +712,14 @@ def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> {
712712
return getAttrs().contains(name);
713713
}
714714

715-
ArrayAttr getStrides() {
715+
ArrayAttr getStrideAttr() {
716716
return getAttrs().getAs<ArrayAttr>("stride");
717717
}
718718

719+
ArrayAttr getBlockAttr() {
720+
return getAttrs().getAs<ArrayAttr>("block");
721+
}
722+
719723
}];
720724

721725
}

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 20 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1298,14 +1298,14 @@ def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure,
12981298
}
12991299

13001300
def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
1301-
AllElementTypesMatch<["mem_desc", "res"]>,
1302-
AllRanksMatch<["mem_desc", "res"]>]> {
1301+
AllElementTypesMatch<["mem_desc", "res"]>]> {
13031302
let arguments = (ins XeGPU_MemDesc:$mem_desc,
13041303
Variadic<Index>: $offsets,
13051304
DenseI64ArrayAttr: $const_offsets,
1305+
OptionalAttr<UnitAttr>:$subgroup_block_io,
13061306
OptionalAttr<DistributeLayoutAttr>:$layout
13071307
);
1308-
let results = (outs XeGPU_ValueType:$res);
1308+
let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$res);
13091309
let assemblyFormat = [{
13101310
$mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
13111311
prop-dict attr-dict `` `:` type(operands) `->` type(results)
@@ -1319,6 +1319,9 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
13191319
Arguments:
13201320
- `mem_desc`: the memory descriptor identifying the SLM region.
13211321
- `offsets`: the coordinates within the matrix to read from.
1322+
- `subgroup_block_io`: [optional] An attribute indicating that the operation can be
1323+
lowered to a subgroup block load. When this attribute is present,
1324+
the offsets are subgroup-uniform across all lanes.
13221325
- `layout`: [optional] An attribute for guiding distributions among
13231326
subgroups and/or work-items. It currently can accept either
13241327
LayoutAttr or SliceAttr.
@@ -1336,21 +1339,24 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
13361339
}
13371340

13381341
ArrayRef<int64_t> getDataShape() {
1339-
return getRes().getType().getShape();
1342+
auto resTy = getRes().getType();
1343+
if (auto vecTy = llvm::dyn_cast<VectorType>(resTy))
1344+
return vecTy.getShape();
1345+
return {};
13401346
}
13411347
}];
13421348

13431349
let hasVerifier = 1;
13441350
}
13451351

13461352
def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
1347-
AllElementTypesMatch<["mem_desc", "data"]>,
1348-
AllRanksMatch<["mem_desc", "data"]>]> {
1353+
AllElementTypesMatch<["mem_desc", "data"]>]> {
13491354
let arguments = (ins
1350-
XeGPU_ValueType:$data,
1355+
AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$data,
13511356
XeGPU_MemDesc:$mem_desc,
13521357
Variadic<Index>: $offsets,
13531358
DenseI64ArrayAttr: $const_offsets,
1359+
OptionalAttr<UnitAttr>:$subgroup_block_io,
13541360
OptionalAttr<DistributeLayoutAttr>:$layout
13551361
);
13561362
let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
@@ -1364,6 +1370,9 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
13641370
- `mem_desc`: the memory descriptor specifying the SLM region.
13651371
- `offsets`: the coordinates within the matrix where the data will be written.
13661372
- `data`: the values to be stored in the matrix.
1373+
- `subgroup_block_io`: [optional] An attribute indicating that the operation can be
1374+
lowered to a subgroup block store. When this attribute is present,
1375+
the offsets are subgroup-uniform across all lanes.
13671376
- `layout`: [optional] An attribute for guiding distributions among
13681377
subgroups and/or work-items. It currently can accept either
13691378
LayoutAttr or SliceAttr.
@@ -1378,49 +1387,15 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
13781387
}
13791388

13801389
ArrayRef<int64_t> getDataShape() {
1381-
return getData().getType().getShape();
1390+
auto DataTy = getData().getType();
1391+
if (auto vecTy = llvm::dyn_cast<VectorType>(DataTy))
1392+
return vecTy.getShape();
1393+
return {};
13821394
}
13831395

13841396
}];
13851397

13861398
let hasVerifier = 1;
13871399
}
13881400

1389-
def XeGPU_MemDescSubviewOp: XeGPU_Op<"mem_desc_subview",
1390-
[Pure, ViewLikeOpInterface, AllElementTypesMatch<["src", "res"]>]> {
1391-
let description = [{
1392-
Creates a subview of a memory descriptor. The resulting memory descriptor can have
1393-
a lower rank than the source; in this case, the result dimensions correspond to the
1394-
higher-order dimensions of the source memory descriptor.
1395-
1396-
Arguments:
1397-
- `src` : a memory descriptor.
1398-
- `offsets` : the coordinates within the matrix the subview will be created from.
1399-
1400-
Results:
1401-
- `res` : a memory descriptor with smaller size.
1402-
1403-
}];
1404-
let arguments = (ins XeGPU_MemDesc:$src,
1405-
Variadic<Index>:$offsets,
1406-
DenseI64ArrayAttr:$const_offsets);
1407-
let results = (outs XeGPU_MemDesc:$res);
1408-
let assemblyFormat = [{$src `` custom<DynamicIndexList>($offsets, $const_offsets) prop-dict
1409-
attr-dict `` `:` qualified(type($src)) `->` qualified(type($res))}];
1410-
let builders = [
1411-
OpBuilder<(ins "Type": $res, "Value":$src, "llvm::ArrayRef<OpFoldResult>": $offsets)>
1412-
];
1413-
1414-
let extraClassDeclaration = [{
1415-
mlir::Value getViewSource() { return getSrc(); }
1416-
1417-
SmallVector<OpFoldResult> getMixedOffsets() {
1418-
return getMixedValues(getConstOffsets(), getOffsets(), getContext());
1419-
}
1420-
}];
1421-
1422-
let hasVerifier = 1;
1423-
}
1424-
1425-
14261401
#endif // MLIR_DIALECT_XEGPU_IR_XEGPUOPS_TD

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -237,19 +237,75 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
237237
return MemDescType::get(getContext(), shape.value_or(getShape()), elementType, getMemLayout());
238238
}
239239

240-
ArrayAttr getStrides() {
240+
ArrayAttr getStrideAttr() {
241241
auto layout = getMemLayout();
242242
if (layout && layout.hasAttr("stride")) {
243-
return layout.getStrides();
243+
return layout.getStrideAttr();
244244
}
245-
246245
// derive and return default strides
247246
SmallVector<int64_t> defaultStrides;
248247
llvm::append_range(defaultStrides, getShape().drop_front());
249248
llvm::append_values(defaultStrides, 1);
250249
Builder builder(getContext());
251250
return builder.getI64ArrayAttr(defaultStrides);
252251
}
252+
253+
ArrayAttr getBlockAttr() {
254+
auto layout = getMemLayout();
255+
if (layout && layout.hasAttr("block")) {
256+
return layout.getBlockAttr();
257+
}
258+
Builder builder(getContext());
259+
return builder.getI64ArrayAttr({});
260+
}
261+
262+
/// Heuristic to determine if the MemDesc uses column-major layout,
263+
/// based on the rank and the value of the first stride dimension.
264+
bool isColMajor() {
265+
auto dim0 = dyn_cast<IntegerAttr>(getStrideAttr()[0]);
266+
return getRank() == 2 && dim0.getInt() == 1;
267+
}
268+
269+
// Get the blocking shape for a MemDescType, which is represented
270+
// as an attribute in MemDescType. By default it is the shape
271+
// of the MemDescType.
272+
SmallVector<int64_t> getBlockShape() {
273+
SmallVector<int64_t> size(getShape());
274+
ArrayAttr blockAttr = getBlockAttr();
275+
if (!blockAttr.empty()) {
276+
size.clear();
277+
for (auto attr : blockAttr.getValue()) {
278+
size.push_back(cast<IntegerAttr>(attr).getInt());
279+
}
280+
}
281+
return size;
282+
}
283+
284+
// Get the strides as a vector of integers.
285+
// If it contains a block attribute, the strides are blocked strides.
286+
//
287+
// The blocking is applied to the base matrix shape derived from the
288+
// memory descriptor's stride information. If the matrix described by
289+
// the memory descriptor is not contiguous, it is assumed that the base
290+
// matrix is contiguous and follows the same memory layout.
291+
//
292+
// It first computes the original matrix shape using the stride info,
293+
// then computes the number of blocks in each dimension of original shape,
294+
// then computes the outer block shape and stride,
295+
// then combines the inner and outer block shape and stride.
296+
// e.g. for `mem_desc<32x256xf16, @block=[16, 8], @stride=[1, 32]>`
297+
// its memory layout tuple is ([2,32,16,8],[128,256,1,16])
298+
// for `mem_desc<256x32xf16, @block=[8, 16]>` with default @stride=[32, 1]
299+
// its memory layout tuple is ([32,2,8,16],[256,128,16,1])
300+
SmallVector<int64_t> getStrideShape();
301+
302+
/// Generates instructions to compute the linearized offset.
303+
// If the memory descriptor is blocked, it returns the linearized offset based on the blocked layout.
304+
// The strides of the memory descriptor are always considered, regardless of whether it is blocked.
305+
Value getLinearOffsets(OpBuilder &builder,
306+
Location loc, ArrayRef<OpFoldResult> offsets);
307+
308+
253309
}];
254310

255311
let hasCustomAssemblyFormat = true;

mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ add_mlir_conversion_library(MLIRXeGPUToXeVM
2121
MLIRIndexDialect
2222
MLIRSCFDialect
2323
MLIRXeGPUDialect
24+
MLIRXeGPUUtils
2425
MLIRPass
2526
MLIRTransforms
2627
MLIRSCFTransforms

0 commit comments

Comments
 (0)