@@ -34,24 +34,19 @@ using namespace mlir;
3434
3535namespace {
3636
37- // Check if there is sg id range attached to the scf.if op.
38- static bool isSgIdRangeSpecified (Operation *op, int64_t &startOfRange,
39- int64_t &endOfRange) {
37+ // Retrieve the RangeAttr if it is specified.
38+ static xegpu::RangeAttr getRangeSpecAttr (Operation *op) {
4039 Operation *parent = op->getParentOp ();
41- // Find the outermost scf::IfOp with xegpu.sg_id_range.
4240 while (parent) {
4341 if (auto ifOp = dyn_cast<scf::IfOp>(parent)) {
4442 if (auto attr = llvm::dyn_cast_or_null<xegpu::RangeAttr>(
4543 ifOp->getAttr (" sg_id_range" ))) {
46- startOfRange = attr.getStart ().getInt ();
47- endOfRange = attr.getEnd ().getInt ();
48- break ;
44+ return attr;
4945 }
5046 }
5147 parent = parent->getParentOp ();
5248 }
53- // Return false if startOfRange is 0
54- return (startOfRange > 0 && endOfRange > startOfRange);
49+ return {};
5550}
5651
5752static std::pair<SmallVector<int64_t >, int >
@@ -101,16 +96,21 @@ genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
10196
10297 Value sgId = rewriter.create <gpu::SubgroupIdOp>(loc, /* upper_bound=*/ nullptr );
10398
104- // adjust the linearId if the range specifier is present
105- int64_t startOfRange = -1 , endOfRange = -1 ;
106- bool sgIdRangeSpecified = isSgIdRangeSpecified (op, startOfRange, endOfRange);
107- if (sgIdRangeSpecified) {
99+ // verify and adjust the sgId if the range specifier is present
100+ xegpu::RangeAttr sgIdRange = getRangeSpecAttr (op);
101+ if (sgIdRange) {
102+ int64_t startOfRange = sgIdRange.getStart ().getInt ();
103+ int64_t endOfRange = sgIdRange.getEnd ().getInt ();
104+ // verify the RangeAttr against the layout attribute
108105 if (layout.getNumSubgroups () != endOfRange - startOfRange)
109106 return rewriter.notifyMatchFailure (
110107 op, " sg_layout size must match the sg_id_range" );
111- Value startOfRangeVal =
112- rewriter.create <arith::ConstantIndexOp>(loc, startOfRange);
113- sgId = rewriter.create <index::SubOp>(loc, sgId, startOfRangeVal);
108+ // adjust the sgId if necessary
109+ if (startOfRange > 0 ) {
110+ Value startOfRangeVal =
111+ rewriter.create <arith::ConstantIndexOp>(loc, startOfRange);
112+ sgId = rewriter.create <index::SubOp>(loc, sgId, startOfRangeVal);
113+ }
114114 }
115115
116116 // Compute the list of subgroup-relative offsets for sub-tensors or sub-memory
0 commit comments