Skip to content

Commit 1c1d525

Browse files
authored
[mlir][omp] Improve canonloop/iv naming (#159773)
Improve the automatic naming of variables defined by the `omp.canonical_loop` operation: 1. The iteration variable gets a name consistent with the cli variable 2. Instead of appending `_s0` for each nesting level, shorten it to `_d<num>` for a perfectly nested loop at depth `<num>` 3. Do not add any suffix to the top-level loop if it is the only top-level loop
1 parent 119cdf7 commit 1c1d525

File tree

3 files changed

+417
-100
lines changed

3 files changed

+417
-100
lines changed

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 229 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,232 @@ struct LLVMPointerPointerLikeModel
7777
};
7878
} // namespace
7979

80+
/// Generate a name of a canonical loop nest of the format
81+
/// `<prefix>(_r<idx>_s<idx>)*`. Hereby, `_r<idx>` identifies the region
82+
/// argument index of an operation that has multiple regions, if the operation
83+
/// has multiple regions.
84+
/// `_s<idx>` identifies the position of an operation within a region, where
85+
/// only operations that may potentially contain loops ("container operations"
86+
/// i.e. have region arguments) are counted. Again, it is omitted if there is
87+
/// only one such operation in a region. If there are canonical loops nested
88+
/// inside each other, also may also use the format `_d<num>` where <num> is the
89+
/// nesting depth of the loop.
90+
///
91+
/// The generated name is a best-effort to make canonical loop unique within an
92+
/// SSA namespace. This also means that regions with IsolatedFromAbove property
93+
/// do not consider any parents or siblings.
94+
static std::string generateLoopNestingName(StringRef prefix,
95+
CanonicalLoopOp op) {
96+
struct Component {
97+
/// If true, this component describes a region operand of an operation (the
98+
/// operand's owner) If false, this component describes an operation located
99+
/// in a parent region
100+
bool isRegionArgOfOp;
101+
bool skip = false;
102+
bool isUnique = false;
103+
104+
size_t idx;
105+
Operation *op;
106+
Region *parentRegion;
107+
size_t loopDepth;
108+
109+
Operation *&getOwnerOp() {
110+
assert(isRegionArgOfOp && "Must describe a region operand");
111+
return op;
112+
}
113+
size_t &getArgIdx() {
114+
assert(isRegionArgOfOp && "Must describe a region operand");
115+
return idx;
116+
}
117+
118+
Operation *&getContainerOp() {
119+
assert(!isRegionArgOfOp && "Must describe a operation of a region");
120+
return op;
121+
}
122+
size_t &getOpPos() {
123+
assert(!isRegionArgOfOp && "Must describe a operation of a region");
124+
return idx;
125+
}
126+
bool isLoopOp() const {
127+
assert(!isRegionArgOfOp && "Must describe a operation of a region");
128+
return isa<CanonicalLoopOp>(op);
129+
}
130+
Region *&getParentRegion() {
131+
assert(!isRegionArgOfOp && "Must describe a operation of a region");
132+
return parentRegion;
133+
}
134+
size_t &getLoopDepth() {
135+
assert(!isRegionArgOfOp && "Must describe a operation of a region");
136+
return loopDepth;
137+
}
138+
139+
void skipIf(bool v = true) { skip = skip || v; }
140+
};
141+
142+
// List of ancestors, from inner to outer.
143+
// Alternates between
144+
// * region argument of an operation
145+
// * operation within a region
146+
SmallVector<Component> components;
147+
148+
// Gather a list of parent regions and operations, and the position within
149+
// their parent
150+
Operation *o = op.getOperation();
151+
while (o) {
152+
// Operation within a region
153+
Region *r = o->getParentRegion();
154+
if (!r)
155+
break;
156+
157+
llvm::ReversePostOrderTraversal<Block *> traversal(&r->getBlocks().front());
158+
size_t idx = 0;
159+
bool found = false;
160+
size_t sequentialIdx = -1;
161+
bool isOnlyContainerOp = true;
162+
for (Block *b : traversal) {
163+
for (Operation &op : *b) {
164+
if (&op == o && !found) {
165+
sequentialIdx = idx;
166+
found = true;
167+
}
168+
if (op.getNumRegions()) {
169+
idx += 1;
170+
if (idx > 1)
171+
isOnlyContainerOp = false;
172+
}
173+
if (found && !isOnlyContainerOp)
174+
break;
175+
}
176+
}
177+
178+
Component &containerOpInRegion = components.emplace_back();
179+
containerOpInRegion.isRegionArgOfOp = false;
180+
containerOpInRegion.isUnique = isOnlyContainerOp;
181+
containerOpInRegion.getContainerOp() = o;
182+
containerOpInRegion.getOpPos() = sequentialIdx;
183+
containerOpInRegion.getParentRegion() = r;
184+
185+
Operation *parent = r->getParentOp();
186+
187+
// Region argument of an operation
188+
Component &regionArgOfOperation = components.emplace_back();
189+
regionArgOfOperation.isRegionArgOfOp = true;
190+
regionArgOfOperation.isUnique = true;
191+
regionArgOfOperation.getArgIdx() = 0;
192+
regionArgOfOperation.getOwnerOp() = parent;
193+
194+
// The IsolatedFromAbove trait of the parent operation implies that each
195+
// individual region argument has its own separate namespace, so no
196+
// ambiguity.
197+
if (!parent || parent->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>())
198+
break;
199+
200+
// Component only needed if operation has multiple region operands. Region
201+
// arguments may be optional, but we currently do not consider this.
202+
if (parent->getRegions().size() > 1) {
203+
auto getRegionIndex = [](Operation *o, Region *r) {
204+
for (auto [idx, region] : llvm::enumerate(o->getRegions())) {
205+
if (&region == r)
206+
return idx;
207+
}
208+
llvm_unreachable("Region not child of its parent operation");
209+
};
210+
regionArgOfOperation.isUnique = false;
211+
regionArgOfOperation.getArgIdx() = getRegionIndex(parent, r);
212+
}
213+
214+
// next parent
215+
o = parent;
216+
}
217+
218+
// Determine whether a region-argument component is not needed
219+
for (Component &c : components)
220+
c.skipIf(c.isRegionArgOfOp && c.isUnique);
221+
222+
// Find runs of nested loops and determine each loop's depth in the loop nest
223+
size_t numSurroundingLoops = 0;
224+
for (Component &c : llvm::reverse(components)) {
225+
if (c.skip)
226+
continue;
227+
228+
// non-skipped multi-argument operands interrupt the loop nest
229+
if (c.isRegionArgOfOp) {
230+
numSurroundingLoops = 0;
231+
continue;
232+
}
233+
234+
// Multiple loops in a region means each of them is the outermost loop of a
235+
// new loop nest
236+
if (!c.isUnique)
237+
numSurroundingLoops = 0;
238+
239+
c.getLoopDepth() = numSurroundingLoops;
240+
241+
// Next loop is surrounded by one more loop
242+
if (isa<CanonicalLoopOp>(c.getContainerOp()))
243+
numSurroundingLoops += 1;
244+
}
245+
246+
// In loop nests, skip all but the innermost loop that contains the depth
247+
// number
248+
bool isLoopNest = false;
249+
for (Component &c : components) {
250+
if (c.skip || c.isRegionArgOfOp)
251+
continue;
252+
253+
if (!isLoopNest && c.getLoopDepth() >= 1) {
254+
// Innermost loop of a loop nest of at least two loops
255+
isLoopNest = true;
256+
} else if (isLoopNest) {
257+
// Non-innermost loop of a loop nest
258+
c.skipIf(c.isUnique);
259+
260+
// If there is no surrounding loop left, this must have been the outermost
261+
// loop; leave loop-nest mode for the next iteration
262+
if (c.getLoopDepth() == 0)
263+
isLoopNest = false;
264+
}
265+
}
266+
267+
// Skip non-loop unambiguous regions (but they should interrupt loop nests, so
268+
// we mark them as skipped only after computing loop nests)
269+
for (Component &c : components)
270+
c.skipIf(!c.isRegionArgOfOp && c.isUnique &&
271+
!isa<CanonicalLoopOp>(c.getContainerOp()));
272+
273+
// Components can be skipped if they are already disambiguated by their parent
274+
// (or does not have a parent)
275+
bool newRegion = true;
276+
for (Component &c : llvm::reverse(components)) {
277+
c.skipIf(newRegion && c.isUnique);
278+
279+
// non-skipped components disambiguate unique children
280+
if (!c.skip)
281+
newRegion = true;
282+
283+
// ...except canonical loops that need a suffix for each nest
284+
if (!c.isRegionArgOfOp && c.getContainerOp())
285+
newRegion = false;
286+
}
287+
288+
// Compile the nesting name string
289+
SmallString<64> Name{prefix};
290+
llvm::raw_svector_ostream NameOS(Name);
291+
for (auto &c : llvm::reverse(components)) {
292+
if (c.skip)
293+
continue;
294+
295+
if (c.isRegionArgOfOp)
296+
NameOS << "_r" << c.getArgIdx();
297+
else if (c.getLoopDepth() >= 1)
298+
NameOS << "_d" << c.getLoopDepth();
299+
else
300+
NameOS << "_s" << c.getOpPos();
301+
}
302+
303+
return NameOS.str().str();
304+
}
305+
80306
void OpenMPDialect::initialize() {
81307
addOperations<
82308
#define GET_OP_LIST
@@ -3172,67 +3398,7 @@ void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
31723398
cliName =
31733399
TypeSwitch<Operation *, std::string>(gen->getOwner())
31743400
.Case([&](CanonicalLoopOp op) {
3175-
// Find the canonical loop nesting: For each ancestor add a
3176-
// "+_r<idx>" suffix (in reverse order)
3177-
SmallVector<std::string> components;
3178-
Operation *o = op.getOperation();
3179-
while (o) {
3180-
if (o->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>())
3181-
break;
3182-
3183-
Region *r = o->getParentRegion();
3184-
if (!r)
3185-
break;
3186-
3187-
auto getSequentialIndex = [](Region *r, Operation *o) {
3188-
llvm::ReversePostOrderTraversal<Block *> traversal(
3189-
&r->getBlocks().front());
3190-
size_t idx = 0;
3191-
for (Block *b : traversal) {
3192-
for (Operation &op : *b) {
3193-
if (&op == o)
3194-
return idx;
3195-
// Only consider operations that are containers as
3196-
// possible children
3197-
if (!op.getRegions().empty())
3198-
idx += 1;
3199-
}
3200-
}
3201-
llvm_unreachable("Operation not part of the region");
3202-
};
3203-
size_t sequentialIdx = getSequentialIndex(r, o);
3204-
components.push_back(("s" + Twine(sequentialIdx)).str());
3205-
3206-
Operation *parent = r->getParentOp();
3207-
if (!parent)
3208-
break;
3209-
3210-
// If the operation has more than one region, also count in
3211-
// which of the regions
3212-
if (parent->getRegions().size() > 1) {
3213-
auto getRegionIndex = [](Operation *o, Region *r) {
3214-
for (auto [idx, region] :
3215-
llvm::enumerate(o->getRegions())) {
3216-
if (&region == r)
3217-
return idx;
3218-
}
3219-
llvm_unreachable("Region not child its parent operation");
3220-
};
3221-
size_t regionIdx = getRegionIndex(parent, r);
3222-
components.push_back(("r" + Twine(regionIdx)).str());
3223-
}
3224-
3225-
// next parent
3226-
o = parent;
3227-
}
3228-
3229-
SmallString<64> Name("canonloop");
3230-
for (const std::string &s : reverse(components)) {
3231-
Name += '_';
3232-
Name += s;
3233-
}
3234-
3235-
return Name;
3401+
return generateLoopNestingName("canonloop", op);
32363402
})
32373403
.Case([&](UnrollHeuristicOp op) -> std::string {
32383404
llvm_unreachable("heuristic unrolling does not generate a loop");
@@ -3323,7 +3489,8 @@ void CanonicalLoopOp::getAsmBlockNames(OpAsmSetBlockNameFn setNameFn) {
33233489

33243490
void CanonicalLoopOp::getAsmBlockArgumentNames(Region &region,
33253491
OpAsmSetValueNameFn setNameFn) {
3326-
setNameFn(region.getArgument(0), "iv");
3492+
std::string ivName = generateLoopNestingName("iv", *this);
3493+
setNameFn(region.getArgument(0), ivName);
33273494
}
33283495

33293496
void CanonicalLoopOp::print(OpAsmPrinter &p) {

0 commit comments

Comments
 (0)