Skip to content

Commit 2fb9ac7

Browse files
committed
refactor
1 parent 4f93bcb commit 2fb9ac7

File tree

1 file changed

+16
-16
lines changed

1 file changed

+16
-16
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,24 +34,19 @@ using namespace mlir;
3434

3535
namespace {
3636

37-
// Check if there is sg id range attached to the scf.if op.
38-
static bool isSgIdRangeSpecified(Operation *op, int64_t &startOfRange,
39-
int64_t &endOfRange) {
37+
// Retrieve the RangeAttr if it is specified.
38+
static xegpu::RangeAttr getRangeSpecAttr(Operation *op) {
4039
Operation *parent = op->getParentOp();
41-
// Find the outermost scf::IfOp with xegpu.sg_id_range.
4240
while (parent) {
4341
if (auto ifOp = dyn_cast<scf::IfOp>(parent)) {
4442
if (auto attr = llvm::dyn_cast_or_null<xegpu::RangeAttr>(
4543
ifOp->getAttr("sg_id_range"))) {
46-
startOfRange = attr.getStart().getInt();
47-
endOfRange = attr.getEnd().getInt();
48-
break;
44+
return attr;
4945
}
5046
}
5147
parent = parent->getParentOp();
5248
}
53-
// Return false if startOfRange is 0
54-
return (startOfRange > 0 && endOfRange > startOfRange);
49+
return {};
5550
}
5651

5752
static std::pair<SmallVector<int64_t>, int>
@@ -101,16 +96,21 @@ genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
10196

10297
Value sgId = rewriter.create<gpu::SubgroupIdOp>(loc, /*upper_bound=*/nullptr);
10398

104-
// adjust the linearId if the range specifier is present
105-
int64_t startOfRange = -1, endOfRange = -1;
106-
bool sgIdRangeSpecified = isSgIdRangeSpecified(op, startOfRange, endOfRange);
107-
if (sgIdRangeSpecified) {
99+
// verify and adjust the sgId if the range specifier is present
100+
xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op);
101+
if (sgIdRange) {
102+
int64_t startOfRange = sgIdRange.getStart().getInt();
103+
int64_t endOfRange = sgIdRange.getEnd().getInt();
104+
// verify the RangeAttr against the layout attribute
108105
if (layout.getNumSubgroups() != endOfRange - startOfRange)
109106
return rewriter.notifyMatchFailure(
110107
op, "sg_layout size must match the sg_id_range");
111-
Value startOfRangeVal =
112-
rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
113-
sgId = rewriter.create<index::SubOp>(loc, sgId, startOfRangeVal);
108+
// adjust the sgId if necessary
109+
if (startOfRange > 0) {
110+
Value startOfRangeVal =
111+
rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
112+
sgId = rewriter.create<index::SubOp>(loc, sgId, startOfRangeVal);
113+
}
114114
}
115115

116116
// Compute the list of subgroup-relative offsets for sub-tensors or sub-memory

0 commit comments

Comments
 (0)