Skip to content

Commit 5965b54

Browse files
committed
Remove DistributionLevel enum
1 parent 3c4a5aa commit 5965b54

File tree

5 files changed

+34
-49
lines changed

5 files changed

+34
-49
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -167,16 +167,6 @@ def XeGPU_FenceScope: I32EnumAttr<"FenceScope",
167167
let cppNamespace = "::mlir::xegpu";
168168
}
169169

170-
def XeGPU_WGLevel: I32EnumAttrCase<"WG", 0, "wg">;
171-
def XeGPU_SGLevel: I32EnumAttrCase<"SG", 1, "sg">;
172-
def XeGPU_WILevel: I32EnumAttrCase<"WI", 2, "wi">;
173-
def XeGPU_DistributionLevel: I32EnumAttr<"DistributionLevel",
174-
"Specify target level for offsets distribution utility.",
175-
[XeGPU_WGLevel, XeGPU_SGLevel, XeGPU_WILevel]> {
176-
let genSpecializedAttr = 0;
177-
let cppNamespace = "::mlir::xegpu";
178-
}
179-
180170
def XeGPU_FenceScopeAttr:
181171
EnumAttr<XeGPU_Dialect, XeGPU_FenceScope, "fence_scope"> {
182172
let summary = [{Describes the scope of fence.
@@ -234,17 +224,17 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
234224
"xegpu::DistributeLayoutAttr",
235225
"dropInstData">,
236226
InterfaceMethod<[{Delinearizes a linear ID into its multidimensional
237-
indices based on the effective `level` layout.}],
227+
indices based on the effective layout level.}],
238228
"FailureOr<SmallVector<Value>>",
239229
"delinearizeId",
240-
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "xegpu::DistributionLevel": $level)>,
230+
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId)>,
241231
InterfaceMethod<[{Generates instructions to compute multidimensional offsets for dist units
242-
assigned to a `level` identified by linearId. The shape parameter
243-
represents the higher-level problem size. Each `level` may access
232+
assigned to a level identified by linearId. The shape parameter
233+
represents the higher-level problem size. Each level may access
244234
multiple blocks according to round-robin distribution rules.}],
245235
"FailureOr<SmallVector<SmallVector<Value>>>",
246236
"computeDistributedOffsets",
247-
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape, "xegpu::DistributionLevel": $level)>,
237+
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>,
248238
InterfaceMethod</*desc=*/[{Check if this layout can be achieved by applying a transpose
249239
to some other layout according to given permutation of (0...n-1).}],
250240
/*retTy=*/"bool",
@@ -487,16 +477,16 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
487477
}
488478

489479
/// Delinearizes a linear ID into its multidimensional indices
490-
/// based on the effective `level` layout.
480+
/// based on the effective level of the layout.
491481
FailureOr<SmallVector<Value>>
492-
delinearizeId(OpBuilder &builder, Location loc, Value linearId, xegpu::DistributionLevel level);
482+
delinearizeId(OpBuilder &builder, Location loc, Value linearId);
493483

494484
/// Generates instructions to compute multidimensional offsets for dist units
495-
/// assigned to a `level` identified by linearId. The shape parameter
485+
/// assigned to a level identified by linearId. The shape parameter
496486
/// represents the higher-level problem size. Each `level` may access
497487
/// multiple blocks according to round-robin distribution rules.
498488
FailureOr<SmallVector<SmallVector<Value>>>
499-
computeDistributedOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape, xegpu::DistributionLevel level);
489+
computeDistributedOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
500490

501491
/// Check if this is slice of some other layout.
502492
bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; }
@@ -653,15 +643,15 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
653643
/// Delinearizes a linear subgroup ID into its multidimensional indices
654644
/// based on the effective subgroup layout.
655645
FailureOr<SmallVector<Value>>
656-
delinearizeId(OpBuilder &builder, Location loc, Value linearId, xegpu::DistributionLevel level);
646+
delinearizeId(OpBuilder &builder, Location loc, Value linearId);
657647

658648
/// Generates instructions to compute multidimensional offsets for blocks
659649
/// assigned to a subgroup identified by linearId. The shape parameter
660650
/// represents the workgroup-level problem size. Each subgroup may access
661651
/// multiple blocks according to round-robin distribution rules.
662652

663653
FailureOr<SmallVector<SmallVector<Value>>>
664-
computeDistributedOffsets(OpBuilder &builder, Location loc, Value linearId,ArrayRef<int64_t> shape, xegpu::DistributionLevel level);
654+
computeDistributedOffsets(OpBuilder &builder, Location loc, Value linearId,ArrayRef<int64_t> shape);
665655

666656
/// Check if this is slice of some other layout.
667657
bool isSliceOf(const xegpu::DistributeLayoutAttr &other);

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -273,8 +273,7 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
273273
}
274274

275275
FailureOr<SmallVector<Value>>
276-
LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId,
277-
xegpu::DistributionLevel idLevel) {
276+
LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
278277

279278
// TODO: handle order attribute
280279
auto hasDefaultOrder = [&]() {
@@ -285,9 +284,9 @@ LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId,
285284
if (!hasDefaultOrder())
286285
return mlir::emitError(loc, "order attribute is currently not supported.");
287286
SmallVector<int64_t> layout;
288-
if (idLevel == xegpu::DistributionLevel::SG) {
287+
if (isForWorkgroup()) {
289288
layout = getEffectiveSgLayoutAsInt();
290-
} else if (idLevel == xegpu::DistributionLevel::WI) {
289+
} else if (isForSubgroup()) {
291290
layout = getEffectiveLaneLayoutAsInt();
292291
} else {
293292
return failure();
@@ -304,14 +303,13 @@ LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId,
304303
/// LayoutAttr.
305304
FailureOr<SmallVector<SmallVector<Value>>>
306305
LayoutAttr::computeDistributedOffsets(OpBuilder &builder, Location loc,
307-
Value linearId, ArrayRef<int64_t> shape,
308-
xegpu::DistributionLevel targetLevel) {
306+
Value linearId, ArrayRef<int64_t> shape) {
309307
SmallVector<int64_t> layout;
310308
SmallVector<int64_t> subShape;
311-
if (targetLevel == DistributionLevel::SG) {
309+
if (isForWorkgroup()) {
312310
layout = getEffectiveSgLayoutAsInt();
313311
subShape = getEffectiveSgDataAsInt();
314-
} else if (targetLevel == DistributionLevel::WI) {
312+
} else if (isForSubgroup()) {
315313
layout = getEffectiveLaneLayoutAsInt();
316314
subShape = getEffectiveLaneDataAsInt();
317315
} else {
@@ -325,7 +323,7 @@ LayoutAttr::computeDistributedOffsets(OpBuilder &builder, Location loc,
325323
}
326324

327325
// delinearize Ids
328-
auto maybeIds = delinearizeId(builder, loc, linearId, targetLevel);
326+
auto maybeIds = delinearizeId(builder, loc, linearId);
329327
if (failed(maybeIds))
330328
return failure();
331329
SmallVector<Value> ids = *maybeIds;
@@ -384,30 +382,28 @@ SliceAttr SliceAttr::flatten() const {
384382
}
385383

386384
FailureOr<SmallVector<Value>>
387-
SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId,
388-
xegpu::DistributionLevel level) {
385+
SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
389386
SliceAttr attr = flatten();
390387
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
391-
return parent.delinearizeId(builder, loc, linearId, level);
388+
return parent.delinearizeId(builder, loc, linearId);
392389
}
393390

394391
// Implements DistributeLayoutAttr::computeDistributedOffsets to generate
395392
// instructions for computing multi-dimensional offsets when distributed by
396393
// LayoutAttr.
397394
FailureOr<SmallVector<SmallVector<Value>>>
398395
SliceAttr::computeDistributedOffsets(OpBuilder &builder, Location loc,
399-
Value linearId, ArrayRef<int64_t> shape,
400-
xegpu::DistributionLevel targetLevel) {
396+
Value linearId, ArrayRef<int64_t> shape) {
401397
assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
402398
if (!isForWorkgroup())
403399
return failure();
404400

405401
SmallVector<int64_t> layout;
406402
SmallVector<int64_t> subShape;
407-
if (targetLevel == DistributionLevel::SG) {
403+
if (isForWorkgroup()) {
408404
layout = getEffectiveSgLayoutAsInt();
409405
subShape = getEffectiveSgDataAsInt();
410-
} else if (targetLevel == DistributionLevel::WI) {
406+
} else if (isForSubgroup()) {
411407
layout = getEffectiveLaneLayoutAsInt();
412408
subShape = getEffectiveLaneDataAsInt();
413409
} else {
@@ -422,7 +418,7 @@ SliceAttr::computeDistributedOffsets(OpBuilder &builder, Location loc,
422418
}
423419

424420
// delinearize Ids
425-
auto maybeIds = delinearizeId(builder, loc, linearId, targetLevel);
421+
auto maybeIds = delinearizeId(builder, loc, linearId);
426422
if (failed(maybeIds))
427423
return failure();
428424

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -911,9 +911,8 @@ static SmallVector<Value> computeDistributedOffsetsForMatrixOp(
911911
PatternRewriter &rewriter, Location loc, xegpu::DistributeLayoutAttr layout,
912912
Value laneId, ArrayRef<int64_t> payloadShape, ValueRange origOffsets) {
913913
SmallVector<Value> newOffsets;
914-
;
915-
auto maybeDescOffsets = layout.computeDistributedOffsets(
916-
rewriter, loc, laneId, payloadShape, xegpu::DistributionLevel::WI);
914+
auto maybeDescOffsets =
915+
layout.computeDistributedOffsets(rewriter, loc, laneId, payloadShape);
917916
if (failed(maybeDescOffsets))
918917
return {};
919918
assert(maybeDescOffsets.value().size() == 1 &&

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,8 @@ genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
114114
// Compute the list of subgroup-relative offsets for sub-tensors or sub-memory
115115
// descriptors to be accessed, based on the layout information.
116116
ArrayRef<int64_t> wgShape = op.getDataShape();
117-
auto maybeDescOffsets = layout.computeDistributedOffsets(
118-
rewriter, loc, sgId, wgShape, xegpu::DistributionLevel::SG);
117+
auto maybeDescOffsets =
118+
layout.computeDistributedOffsets(rewriter, loc, sgId, wgShape);
119119
if (failed(maybeDescOffsets))
120120
return failure();
121121

@@ -831,8 +831,8 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
831831
// Get subgroup id
832832
Value sgId =
833833
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
834-
auto sgOffsets = layout.computeDistributedOffsets(
835-
rewriter, loc, sgId, wgShape, xegpu::DistributionLevel::SG);
834+
auto sgOffsets =
835+
layout.computeDistributedOffsets(rewriter, loc, sgId, wgShape);
836836
if (failed(sgOffsets))
837837
return failure();
838838

@@ -1053,8 +1053,8 @@ struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
10531053

10541054
Value sgId =
10551055
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
1056-
auto sgOffsets = layout.computeDistributedOffsets(
1057-
rewriter, loc, sgId, wgShape, xegpu::DistributionLevel::SG);
1056+
auto sgOffsets =
1057+
layout.computeDistributedOffsets(rewriter, loc, sgId, wgShape);
10581058
if (failed(sgOffsets))
10591059
return failure();
10601060

mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,8 @@ class TestStepOpPattern : public OpConversionPattern<vector::StepOp> {
200200

201201
Value sgId =
202202
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
203-
auto maybeOffsets = sliceAttr.computeDistributedOffsets(
204-
rewriter, loc, sgId, wgShape, xegpu::DistributionLevel::SG);
203+
auto maybeOffsets =
204+
sliceAttr.computeDistributedOffsets(rewriter, loc, sgId, wgShape);
205205
if (failed(maybeOffsets))
206206
return failure();
207207

0 commit comments

Comments
 (0)