Skip to content

Commit f80ee32

Browse files
committed
Add offset calculation
1 parent 887f978 commit f80ee32

File tree

8 files changed

+166
-113
lines changed

8 files changed

+166
-113
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,11 @@ class SliceAttr;
3030
} // namespace xegpu
3131
} // namespace mlir
3232

33+
// clang-format off
34+
#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.h.inc>
3335
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.h.inc>
3436
#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.h.inc>
35-
#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.h.inc>
37+
// clang-format on
3638

3739
#define GET_ATTRDEF_CLASSES
3840
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.h.inc>

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,16 @@ def XeGPU_FenceScope: I32EnumAttr<"FenceScope",
167167
let cppNamespace = "::mlir::xegpu";
168168
}
169169

170+
def XeGPU_WGLevel: I32EnumAttrCase<"WG", 0, "wg">;
171+
def XeGPU_SGLevel: I32EnumAttrCase<"SG", 1, "sg">;
172+
def XeGPU_WILevel: I32EnumAttrCase<"WI", 2, "wi">;
173+
def XeGPU_DistributionLevel: I32EnumAttr<"DistributionLevel",
174+
"The enumeration for the distribution level (workgroup, subgroup, or work-item).",
175+
[XeGPU_WGLevel, XeGPU_SGLevel, XeGPU_WILevel]> {
176+
let genSpecializedAttr = 0;
177+
let cppNamespace = "::mlir::xegpu";
178+
}
179+
170180
def XeGPU_FenceScopeAttr:
171181
EnumAttr<XeGPU_Dialect, XeGPU_FenceScope, "fence_scope"> {
172182
let summary = [{Describes the scope of fence.
@@ -223,18 +233,18 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
223233
InterfaceMethod<"Derive a new layout by dropping InstData",
224234
"xegpu::DistributeLayoutAttr",
225235
"dropInstData">,
226-
InterfaceMethod<[{Delinearizes a linear subgroup ID into its multidimensional
227-
indices based on the effective subgroup layout.}],
236+
InterfaceMethod<[{Delinearizes a linear ID into its multidimensional
237+
indices based on the effective `level` layout.}],
228238
"FailureOr<SmallVector<Value>>",
229-
"delinearizeSubgroupId",
230-
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId)>,
231-
InterfaceMethod<[{Generates instructions to compute multidimensional offsets for blocks
232-
assigned to a subgroup identified by linearId. The shape parameter
233-
represents the workgroup-level problem size. Each subgroup may access
239+
"delinearizeId",
240+
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "xegpu::DistributionLevel": $level)>,
241+
InterfaceMethod<[{Generates instructions to compute multidimensional offsets for dist units
242+
assigned to a `level` identified by linearId. The shape parameter
243+
represents the higher-level problem size. Each `level` may access
234244
multiple blocks according to round-robin distribution rules.}],
235245
"FailureOr<SmallVector<SmallVector<Value>>>",
236-
"getOffsets",
237-
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>,
246+
"computeDistributedCoords",
247+
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape, "xegpu::DistributionLevel": $level)>,
238248
InterfaceMethod</*desc=*/[{Check if this layout can be achieved by applying a transpose
239249
to some other layout according to given permutation of (0...n-1).}],
240250
/*retTy=*/"bool",
@@ -476,17 +486,17 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
476486
return {};
477487
}
478488

479-
/// Delinearizes a linear subgroup ID into its multidimensional indices
480-
/// based on the effective subgroup layout.
489+
/// Delinearizes a linear ID into its multidimensional indices
490+
/// based on the effective `level` layout.
481491
FailureOr<SmallVector<Value>>
482-
delinearizeSubgroupId(OpBuilder &builder, Location loc, Value linearId);
492+
delinearizeId(OpBuilder &builder, Location loc, Value linearId, xegpu::DistributionLevel level);
483493

484-
/// Generates instructions to compute multidimensional offsets for blocks
485-
/// assigned to a subgroup identified by linearId. The shape parameter
486-
/// represents the workgroup-level problem size. Each subgroup may access
494+
/// Generates instructions to compute multidimensional offsets for distribution units
495+
/// assigned to the subgroup or work-item (per `level`) identified by linearId.
496+
/// The shape parameter represents the higher-level problem size. Each subgroup or work-item may access
487497
/// multiple blocks according to round-robin distribution rules.
488498
FailureOr<SmallVector<SmallVector<Value>>>
489-
getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
499+
computeDistributedCoords(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape, xegpu::DistributionLevel level);
490500

491501
/// Check if this is slice of some other layout.
492502
bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; }
@@ -643,14 +653,15 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
643653
/// Delinearizes a linear ID into its multidimensional indices
644654
/// based on the effective layout for the given `level`.
645655
FailureOr<SmallVector<Value>>
646-
delinearizeSubgroupId(OpBuilder &builder, Location loc, Value linearId);
656+
delinearizeId(OpBuilder &builder, Location loc, Value linearId, xegpu::DistributionLevel level);
647657

648658
/// Generates instructions to compute multidimensional offsets for blocks
649659
/// assigned to the subgroup or work-item (per `level`) identified by linearId.
650660
/// The shape parameter represents the higher-level problem size. Each subgroup or work-item may access
651661
/// multiple blocks according to round-robin distribution rules.
662+
652663
FailureOr<SmallVector<SmallVector<Value>>>
653-
getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
664+
computeDistributedCoords(OpBuilder &builder, Location loc, Value linearId,ArrayRef<int64_t> shape, xegpu::DistributionLevel level);
654665

655666
/// Check if this is slice of some other layout.
656667
bool isSliceOf(const xegpu::DistributeLayoutAttr &other);

mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
2626
The pass distributes subgroup level (SIMD) XeGPU ops to work items.
2727
}];
2828
let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
29-
"vector::VectorDialect"];
29+
"vector::VectorDialect", "index::IndexDialect"];
3030
}
3131

3232
def XeGPUPropagateLayout : Pass<"xegpu-propagate-layout"> {

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 86 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -38,47 +38,47 @@ void XeGPUDialect::initialize() {
3838
>();
3939
}
4040

41-
/// Generates instructions to compute offsets for a subgroup identified by
42-
/// its multidimensional indices (sgId), using the specified subgroup layout
43-
/// (sgLayout), subgroup data dimensions (sizePerSg), and the overall data
44-
/// dimensions (sizePerWg).
41+
// A `srcShape` consists of N distribution units, each being `subShapesLayout` x
42+
// `subShape`. A `delinearizedId` is used to identify a particular `subShape`
43+
// within each distribution unit.
4544
static SmallVector<SmallVector<Value>>
46-
genOffsetsComputingInsts(OpBuilder &builder, Location loc,
47-
SmallVector<Value> sgId, ArrayRef<int64_t> sgLayout,
48-
ArrayRef<int64_t> sizePerSg,
49-
ArrayRef<int64_t> sizePerWg) {
50-
45+
genOffsets(OpBuilder &builder, Location loc, SmallVector<Value> delinearizedId,
46+
ArrayRef<int64_t> subShapesLayout, ArrayRef<int64_t> subShape,
47+
ArrayRef<int64_t> srcShape) {
5148
SmallVector<SmallVector<Value>> offsets;
5249

53-
// nd local offset, localOffset[i] = sgId[i] * sizePerSg[i]
54-
SmallVector<Value> localOffsets = llvm::map_to_vector(
55-
llvm::zip(sgId, sizePerSg), [&](const auto &t) -> Value {
50+
// A distribution unit must be less than or equal to `srcShape`
51+
SmallVector<int64_t> distUnitShape = llvm::map_to_vector(
52+
llvm::zip_equal(srcShape,
53+
computeElementwiseMul(subShapesLayout, subShape)),
54+
[](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });
55+
56+
// Get the offset of `subShape` within a distribution unit.
57+
SmallVector<Value> distUnitLocalOffset = llvm::map_to_vector(
58+
llvm::zip(delinearizedId, subShape), [&](const auto &t) -> Value {
5659
return builder.createOrFold<index::MulOp>(
5760
loc, std::get<0>(t),
5861
builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t)));
5962
});
6063

61-
// distUnit[i] is the minimum value between sizePerWg[i] and
62-
// sgLayout[i] * sizePerSg[i]
63-
SmallVector<int64_t> distUnit = llvm::map_to_vector(
64-
llvm::zip_equal(sizePerWg, computeElementwiseMul(sgLayout, sizePerSg)),
65-
[](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });
66-
64+
// For each dist unit
6765
for (SmallVector<int64_t> unitOffs :
68-
StaticTileOffsetRange(sizePerWg, distUnit)) {
66+
StaticTileOffsetRange(srcShape, distUnitShape)) {
67+
// Get dist unit offset within `srcShape`.
6968
SmallVector<Value> base =
7069
llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value {
7170
return arith::ConstantIndexOp::create(builder, loc, d);
7271
});
73-
74-
SmallVector<Value> adds = llvm::map_to_vector(
75-
llvm::zip_equal(base, localOffsets), [&](const auto &t) -> Value {
76-
return builder.createOrFold<arith::AddIOp>(loc, std::get<0>(t),
77-
std::get<1>(t));
78-
});
79-
72+
// Calculate `subShape` offset within `srcShape`.
73+
SmallVector<Value> adds =
74+
llvm::map_to_vector(llvm::zip_equal(base, distUnitLocalOffset),
75+
[&](const auto &t) -> Value {
76+
return builder.createOrFold<arith::AddIOp>(
77+
loc, std::get<0>(t), std::get<1>(t));
78+
});
79+
// Do not go beyond `srcShape` bounds.
8080
SmallVector<Value> mods = llvm::map_to_vector(
81-
llvm::zip_equal(adds, sizePerWg), [&](const auto &t) -> Value {
81+
llvm::zip_equal(adds, srcShape), [&](const auto &t) -> Value {
8282
return builder.createOrFold<index::RemUOp>(
8383
loc, std::get<0>(t),
8484
arith::ConstantIndexOp::create(builder, loc, std::get<1>(t)));
@@ -268,12 +268,8 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
268268
}
269269

270270
FailureOr<SmallVector<Value>>
271-
LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
272-
Value linearId) {
273-
// delinearizeSubgroupId is only available for
274-
// workgroup-level layout attribute
275-
if (!isForWorkgroup())
276-
return failure();
271+
LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId,
272+
xegpu::DistributionLevel idLevel) {
277273

278274
// TODO: handle order attribute
279275
auto hasDefaultOrder = [&]() {
@@ -283,41 +279,53 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
283279
};
284280
if (!hasDefaultOrder())
285281
return mlir::emitError(loc, "order attribute is currently not supported.");
286-
287-
auto dims =
288-
llvm::map_to_vector(getEffectiveSgLayoutAsInt(), [&](int64_t d) -> Value {
289-
return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
290-
});
282+
SmallVector<int64_t> layout;
283+
if (idLevel == xegpu::DistributionLevel::SG) {
284+
layout = getEffectiveSgLayoutAsInt();
285+
} else if (idLevel == xegpu::DistributionLevel::WI) {
286+
layout = getEffectiveLaneLayoutAsInt();
287+
} else {
288+
return failure();
289+
}
290+
auto dims = llvm::map_to_vector(layout, [&](int64_t d) -> Value {
291+
return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
292+
});
291293

292294
return affine::delinearizeIndex(builder, loc, linearId, dims);
293295
}
294296

295-
/// Implements DistributeLayoutAttr::getOffsets to generate
297+
/// Implements DistributeLayoutAttr::computeDistributedCoords to generate
296298
/// instructions for computing multi-dimensional offsets when distributed by
297299
/// LayoutAttr.
298300
FailureOr<SmallVector<SmallVector<Value>>>
299-
LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
300-
ArrayRef<int64_t> shape) {
301-
if (!isForWorkgroup())
301+
LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
302+
Value linearId, ArrayRef<int64_t> shape,
303+
xegpu::DistributionLevel targetLevel) {
304+
SmallVector<int64_t> layout;
305+
SmallVector<int64_t> subShape;
306+
if (targetLevel == DistributionLevel::SG) {
307+
layout = getEffectiveSgLayoutAsInt();
308+
subShape = getEffectiveSgDataAsInt();
309+
} else if (targetLevel == DistributionLevel::WI) {
310+
layout = getEffectiveLaneLayoutAsInt();
311+
subShape = getEffectiveLaneDataAsInt();
312+
} else {
302313
return failure();
303-
304-
SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
305-
SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt();
306-
if (sgShape.empty()) {
307-
if (auto derivedShape = computeShapeRatio(shape, sgLayout))
308-
sgShape = derivedShape.value();
314+
}
315+
if (subShape.empty()) {
316+
if (auto derivedShape = computeShapeRatio(shape, layout))
317+
subShape = derivedShape.value();
309318
else
310319
return failure();
311320
}
312321

313322
// delinearize Ids
314-
auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
323+
auto maybeIds = delinearizeId(builder, loc, linearId, targetLevel);
315324
if (failed(maybeIds))
316325
return failure();
317-
SmallVector<Value> sgIds = *maybeIds;
326+
SmallVector<Value> ids = *maybeIds;
318327

319-
return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
320-
shape);
328+
return genOffsets(builder, loc, ids, layout, subShape, shape);
321329
}
322330

323331
//===----------------------------------------------------------------------===//
@@ -371,34 +379,45 @@ SliceAttr SliceAttr::flatten() const {
371379
}
372380

373381
FailureOr<SmallVector<Value>>
374-
SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
375-
Value linearId) {
382+
SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId,
383+
xegpu::DistributionLevel level) {
376384
SliceAttr attr = flatten();
377385
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
378-
return parent.delinearizeSubgroupId(builder, loc, linearId);
386+
return parent.delinearizeId(builder, loc, linearId, level);
379387
}
380388

381-
/// Implements DistributeLayoutAttr::getOffsets to generate
382-
/// instructions for computing multi-dimensional offsets when distributed by
383-
/// SliceAttr.
389+
// Implements DistributeLayoutAttr::computeDistributedCoords to generate
390+
// instructions for computing multi-dimensional offsets when distributed by
391+
// SliceAttr.
384392
FailureOr<SmallVector<SmallVector<Value>>>
385-
SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
386-
ArrayRef<int64_t> shape) {
393+
SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
394+
Value linearId, ArrayRef<int64_t> shape,
395+
xegpu::DistributionLevel targetLevel) {
387396
assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
388397
if (!isForWorkgroup())
389398
return failure();
390399

391-
SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
392-
SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt();
393-
if (sgShape.empty()) {
394-
if (auto derivedShape = computeShapeRatio(shape, sgLayout))
395-
sgShape = derivedShape.value();
400+
SmallVector<int64_t> layout;
401+
SmallVector<int64_t> subShape;
402+
if (targetLevel == DistributionLevel::SG) {
403+
layout = getEffectiveSgLayoutAsInt();
404+
subShape = getEffectiveSgDataAsInt();
405+
} else if (targetLevel == DistributionLevel::WI) {
406+
layout = getEffectiveLaneLayoutAsInt();
407+
subShape = getEffectiveLaneDataAsInt();
408+
} else {
409+
return failure();
410+
}
411+
412+
if (subShape.empty()) {
413+
if (auto derivedShape = computeShapeRatio(shape, layout))
414+
subShape = derivedShape.value();
396415
else
397416
return failure();
398417
}
399418

400419
// delinearize Ids
401-
auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
420+
auto maybeIds = delinearizeId(builder, loc, linearId, targetLevel);
402421
if (failed(maybeIds))
403422
return failure();
404423

@@ -408,8 +427,7 @@ SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
408427
SmallVector<Value> sgIds =
409428
XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);
410429

411-
return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
412-
shape);
430+
return genOffsets(builder, loc, sgIds, layout, subShape, shape);
413431
}
414432

415433
bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {

0 commit comments

Comments
 (0)