Skip to content

Commit 68c4c83

Browse files
authored
[MLIR][XeGPU] Matrix load/store subgroup distribution (#165008)
1 parent ca00234 commit 68c4c83

File tree

11 files changed

+417
-119
lines changed

11 files changed

+417
-119
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,11 @@ class SliceAttr;
3030
} // namespace xegpu
3131
} // namespace mlir
3232

33+
// clang-format off
34+
#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.h.inc>
3335
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.h.inc>
3436
#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.h.inc>
35-
#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.h.inc>
37+
// clang-format on
3638

3739
#define GET_ATTRDEF_CLASSES
3840
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.h.inc>

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -223,17 +223,17 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
223223
InterfaceMethod<"Derive a new layout by dropping InstData",
224224
"xegpu::DistributeLayoutAttr",
225225
"dropInstData">,
226-
InterfaceMethod<[{Delinearizes a linear subgroup ID into its multidimensional
227-
indices based on the effective subgroup layout.}],
226+
InterfaceMethod<[{Delinearizes a linear ID into its multidimensional
227+
indices based on the effective layout level.}],
228228
"FailureOr<SmallVector<Value>>",
229-
"delinearizeSubgroupId",
229+
"delinearizeId",
230230
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId)>,
231-
InterfaceMethod<[{Generates instructions to compute multidimensional offsets for blocks
232-
assigned to a subgroup identified by linearId. The shape parameter
233-
represents the workgroup-level problem size. Each subgroup may access
231+
InterfaceMethod<[{Generates instructions to compute multidimensional coordinates for dist units
232+
assigned to a level identified by linearId. The shape parameter
233+
represents the higher-level problem size. Each level may access
234234
multiple blocks according to round-robin distribution rules.}],
235235
"FailureOr<SmallVector<SmallVector<Value>>>",
236-
"getOffsets",
236+
"computeDistributedCoords",
237237
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>,
238238
InterfaceMethod</*desc=*/[{Check if this layout can be achieved by applying a transpose
239239
to some other layout according to given permutation of (0...n-1).}],
@@ -476,17 +476,17 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
476476
return {};
477477
}
478478

479-
/// Delinearizes a linear subgroup ID into its multidimensional indices
480-
/// based on the effective subgroup layout.
479+
/// Delinearizes a linear ID into its multidimensional indices
480+
/// based on the effective level of the layout.
481481
FailureOr<SmallVector<Value>>
482-
delinearizeSubgroupId(OpBuilder &builder, Location loc, Value linearId);
482+
delinearizeId(OpBuilder &builder, Location loc, Value linearId);
483483

484-
/// Generates instructions to compute multidimensional offsets for blocks
485-
/// assigned to a subgroup identified by linearId. The shape parameter
486-
/// represents the workgroup-level problem size. Each subgroup may access
484+
/// Generates instructions to compute multidimensional coordinates for dist units
485+
/// assigned to a level identified by linearId. The shape parameter
486+
/// represents the higher-level problem size. Each `level` may access
487487
/// multiple blocks according to round-robin distribution rules.
488488
FailureOr<SmallVector<SmallVector<Value>>>
489-
getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
489+
computeDistributedCoords(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
490490

491491
/// Check if this is slice of some other layout.
492492
bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; }
@@ -643,14 +643,15 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
643643
/// Delinearizes a linear subgroup ID into its multidimensional indices
644644
/// based on the effective subgroup layout.
645645
FailureOr<SmallVector<Value>>
646-
delinearizeSubgroupId(OpBuilder &builder, Location loc, Value linearId);
646+
delinearizeId(OpBuilder &builder, Location loc, Value linearId);
647647

648-
/// Generates instructions to compute multidimensional offsets for blocks
648+
/// Generates instructions to compute multidimensional coordinates for blocks
649649
/// assigned to a subgroup identified by linearId. The shape parameter
650650
/// represents the workgroup-level problem size. Each subgroup may access
651651
/// multiple blocks according to round-robin distribution rules.
652+
652653
FailureOr<SmallVector<SmallVector<Value>>>
653-
getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
654+
computeDistributedCoords(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
654655

655656
/// Check if this is slice of some other layout.
656657
bool isSliceOf(const xegpu::DistributeLayoutAttr &other);

mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
2626
The pass distributes subgroup level (SIMD) XeGPU ops to work items.
2727
}];
2828
let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
29-
"vector::VectorDialect"];
29+
"vector::VectorDialect", "index::IndexDialect"];
3030
}
3131

3232
def XeGPUPropagateLayout : Pass<"xegpu-propagate-layout"> {

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,8 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
562562
VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
563563
if (!valOrResVecTy)
564564
valOrResVecTy = VectorType::get(1, data.getType());
565+
if (valOrResVecTy.getShape().size() != 1)
566+
return rewriter.notifyMatchFailure(op, "Expected 1D data vector.");
565567

566568
int64_t elemBitWidth =
567569
valOrResVecTy.getElementType().getIntOrFloatBitWidth();

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 91 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -37,55 +37,61 @@ void XeGPUDialect::initialize() {
3737
>();
3838
}
3939

40-
/// Generates instructions to compute offsets for a subgroup identified by
41-
/// its multidimensional indices (sgId), using the specified subgroup layout
42-
/// (sgLayout), subgroup data dimensions (sizePerSg), and the overall data
43-
/// dimensions (sizePerWg).
40+
// A `srcShape` consists of N distribution units, each being `subShapesLayout` x
41+
// `subShape`. A `delinearizedId` is used to identify a particular `subShape`
42+
// within each distribution unit.
43+
// Example:
44+
// WG data is 128x256. SG data is 16x32 in a 4x2 layout, which gives a
45+
// distribution unit of shape 64x64; there are 2x4 such distribution units.
46+
// `delinearizedId` is used to identify a 16x32 of a subgroup in each
47+
// distribution unit.
4448
static SmallVector<SmallVector<Value>>
45-
genOffsetsComputingInsts(OpBuilder &builder, Location loc,
46-
SmallVector<Value> sgId, ArrayRef<int64_t> sgLayout,
47-
ArrayRef<int64_t> sizePerSg,
48-
ArrayRef<int64_t> sizePerWg) {
49-
50-
SmallVector<SmallVector<Value>> offsets;
49+
genCoordinates(OpBuilder &builder, Location loc,
50+
SmallVector<Value> delinearizedId,
51+
ArrayRef<int64_t> subShapesLayout, ArrayRef<int64_t> subShape,
52+
ArrayRef<int64_t> srcShape) {
53+
SmallVector<SmallVector<Value>> coordinates;
54+
55+
// A distribution unit must be less than or equal to `srcShape`
56+
SmallVector<int64_t> distUnitShape = llvm::map_to_vector(
57+
llvm::zip_equal(srcShape,
58+
computeElementwiseMul(subShapesLayout, subShape)),
59+
[](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });
5160

52-
// nd local offset, localOffset[i] = sgId[i] * sizePerSg[i]
53-
SmallVector<Value> localOffsets = llvm::map_to_vector(
54-
llvm::zip(sgId, sizePerSg), [&](const auto &t) -> Value {
61+
// Get the offset of `subShape` within a distribution unit.
62+
SmallVector<Value> distUnitLocalOffset = llvm::map_to_vector(
63+
llvm::zip(delinearizedId, subShape), [&](const auto &t) -> Value {
5564
return builder.createOrFold<index::MulOp>(
5665
loc, std::get<0>(t),
5766
builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t)));
5867
});
5968

60-
// distUnit[i] is the minimum value between sizePerWg[i] and
61-
// sgLayout[i] * sizePerSg[i]
62-
SmallVector<int64_t> distUnit = llvm::map_to_vector(
63-
llvm::zip_equal(sizePerWg, computeElementwiseMul(sgLayout, sizePerSg)),
64-
[](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });
65-
69+
// For each dist unit
6670
for (SmallVector<int64_t> unitOffs :
67-
StaticTileOffsetRange(sizePerWg, distUnit)) {
71+
StaticTileOffsetRange(srcShape, distUnitShape)) {
72+
// Get dist unit offset within `srcShape`.
6873
SmallVector<Value> base =
6974
llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value {
7075
return arith::ConstantIndexOp::create(builder, loc, d);
7176
});
72-
73-
SmallVector<Value> adds = llvm::map_to_vector(
74-
llvm::zip_equal(base, localOffsets), [&](const auto &t) -> Value {
75-
return builder.createOrFold<arith::AddIOp>(loc, std::get<0>(t),
76-
std::get<1>(t));
77-
});
78-
77+
// Calculate `subShape` offset within `srcShape`.
78+
SmallVector<Value> adds =
79+
llvm::map_to_vector(llvm::zip_equal(base, distUnitLocalOffset),
80+
[&](const auto &t) -> Value {
81+
return builder.createOrFold<arith::AddIOp>(
82+
loc, std::get<0>(t), std::get<1>(t));
83+
});
84+
// Do not go beyond `srcShape` bounds.
7985
SmallVector<Value> mods = llvm::map_to_vector(
80-
llvm::zip_equal(adds, sizePerWg), [&](const auto &t) -> Value {
86+
llvm::zip_equal(adds, srcShape), [&](const auto &t) -> Value {
8187
return builder.createOrFold<index::RemUOp>(
8288
loc, std::get<0>(t),
8389
arith::ConstantIndexOp::create(builder, loc, std::get<1>(t)));
8490
});
8591

86-
offsets.push_back(mods);
92+
coordinates.push_back(mods);
8793
}
88-
return offsets;
94+
return coordinates;
8995
}
9096

9197
// Checks if the given shape can be evenly distributed based on the layout
@@ -272,12 +278,7 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
272278
}
273279

274280
FailureOr<SmallVector<Value>>
275-
LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
276-
Value linearId) {
277-
// delinearizeSubgroupId is only available for
278-
// workgroup-level layout attribute
279-
if (!isForWorkgroup())
280-
return failure();
281+
LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
281282

282283
// TODO: handle order attribute
283284
auto hasDefaultOrder = [&]() {
@@ -287,41 +288,52 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
287288
};
288289
if (!hasDefaultOrder())
289290
return mlir::emitError(loc, "order attribute is currently not supported.");
290-
291-
auto dims =
292-
llvm::map_to_vector(getEffectiveSgLayoutAsInt(), [&](int64_t d) -> Value {
293-
return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
294-
});
291+
SmallVector<int64_t> layout;
292+
if (isForWorkgroup()) {
293+
layout = getEffectiveSgLayoutAsInt();
294+
} else if (isForSubgroup()) {
295+
layout = getEffectiveLaneLayoutAsInt();
296+
} else {
297+
return failure();
298+
}
299+
auto dims = llvm::map_to_vector(layout, [&](int64_t d) -> Value {
300+
return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
301+
});
295302

296303
return affine::delinearizeIndex(builder, loc, linearId, dims);
297304
}
298305

299-
/// Implements DistributeLayoutAttr::getOffsets to generate
306+
/// Implements DistributeLayoutAttr::computeDistributedCoords to generate
300307
/// instructions for computing multi-dimensional offsets when distributed by
301308
/// LayoutAttr.
302309
FailureOr<SmallVector<SmallVector<Value>>>
303-
LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
304-
ArrayRef<int64_t> shape) {
305-
if (!isForWorkgroup())
310+
LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
311+
Value linearId, ArrayRef<int64_t> shape) {
312+
SmallVector<int64_t> layout;
313+
SmallVector<int64_t> subShape;
314+
if (isForWorkgroup()) {
315+
layout = getEffectiveSgLayoutAsInt();
316+
subShape = getEffectiveSgDataAsInt();
317+
} else if (isForSubgroup()) {
318+
layout = getEffectiveLaneLayoutAsInt();
319+
subShape = getEffectiveLaneDataAsInt();
320+
} else {
306321
return failure();
307-
308-
SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
309-
SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt();
310-
if (sgShape.empty()) {
311-
if (auto derivedShape = computeShapeRatio(shape, sgLayout))
312-
sgShape = derivedShape.value();
322+
}
323+
if (subShape.empty()) {
324+
if (auto derivedShape = computeShapeRatio(shape, layout))
325+
subShape = derivedShape.value();
313326
else
314327
return failure();
315328
}
316329

317330
// delinearize Ids
318-
auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
331+
auto maybeIds = delinearizeId(builder, loc, linearId);
319332
if (failed(maybeIds))
320333
return failure();
321-
SmallVector<Value> sgIds = *maybeIds;
334+
SmallVector<Value> ids = *maybeIds;
322335

323-
return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
324-
shape);
336+
return genCoordinates(builder, loc, ids, layout, subShape, shape);
325337
}
326338

327339
//===----------------------------------------------------------------------===//
@@ -375,34 +387,43 @@ SliceAttr SliceAttr::flatten() const {
375387
}
376388

377389
FailureOr<SmallVector<Value>>
378-
SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
379-
Value linearId) {
390+
SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
380391
SliceAttr attr = flatten();
381392
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
382-
return parent.delinearizeSubgroupId(builder, loc, linearId);
393+
return parent.delinearizeId(builder, loc, linearId);
383394
}
384395

385-
/// Implements DistributeLayoutAttr::getOffsets to generate
386-
/// instructions for computing multi-dimensional offsets when distributed by
387-
/// SliceAttr.
396+
// Implements DistributeLayoutAttr::computeDistributedCoords to generate
397+
// instructions for computing multi-dimensional offsets when distributed by
398+
// SliceAttr.
388399
FailureOr<SmallVector<SmallVector<Value>>>
389-
SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
390-
ArrayRef<int64_t> shape) {
400+
SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
401+
Value linearId, ArrayRef<int64_t> shape) {
391402
assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
392403
if (!isForWorkgroup())
393404
return failure();
394405

395-
SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
396-
SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt();
397-
if (sgShape.empty()) {
398-
if (auto derivedShape = computeShapeRatio(shape, sgLayout))
399-
sgShape = derivedShape.value();
406+
SmallVector<int64_t> layout;
407+
SmallVector<int64_t> subShape;
408+
if (isForWorkgroup()) {
409+
layout = getEffectiveSgLayoutAsInt();
410+
subShape = getEffectiveSgDataAsInt();
411+
} else if (isForSubgroup()) {
412+
layout = getEffectiveLaneLayoutAsInt();
413+
subShape = getEffectiveLaneDataAsInt();
414+
} else {
415+
return failure();
416+
}
417+
418+
if (subShape.empty()) {
419+
if (auto derivedShape = computeShapeRatio(shape, layout))
420+
subShape = derivedShape.value();
400421
else
401422
return failure();
402423
}
403424

404425
// delinearize Ids
405-
auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
426+
auto maybeIds = delinearizeId(builder, loc, linearId);
406427
if (failed(maybeIds))
407428
return failure();
408429

@@ -412,8 +433,7 @@ SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
412433
SmallVector<Value> sgIds =
413434
XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);
414435

415-
return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
416-
shape);
436+
return genCoordinates(builder, loc, sgIds, layout, subShape, shape);
417437
}
418438

419439
bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {

0 commit comments

Comments
 (0)