Skip to content

Commit 2e40c56

Browse files
amd-eochoalokuhar
andauthored
[MLIR][Transform] Prefer entry points in current module (#151323)
The transform interpreter previously looked for the entry point using a recursive walk in pre-order. This makes it so that any named_sequence operation with an arbitrary level of nested-ness will be used as the entry point for the transform interpreter as long as it is placed before another one. This change makes it so that code like the one reported in #119578 works as expected. Closes #119578 Some comments: alternatively, it would also be possible to solve this issue in a slightly more elegant manner. We could define a new walker iterator that iterates through the operations in a breadth first search. --------- Co-authored-by: Jakub Kuderski <[email protected]>
1 parent 2bbc614 commit 2e40c56

File tree

2 files changed

+99
-10
lines changed

2 files changed

+99
-10
lines changed

mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp

Lines changed: 76 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -121,23 +121,89 @@ ModuleOp transform::detail::getPreloadedTransformModule(MLIRContext *context) {
121121
->getLibraryModule();
122122
}
123123

124+
static transform::TransformOpInterface
125+
findTransformEntryPointNonRecursive(Operation *op, StringRef entryPoint) {
126+
for (Region &region : op->getRegions()) {
127+
for (Block &block : region.getBlocks()) {
128+
for (auto namedSequenceOp : block.getOps<transform::NamedSequenceOp>()) {
129+
if (namedSequenceOp.getSymName() == entryPoint) {
130+
return cast<transform::TransformOpInterface>(
131+
namedSequenceOp.getOperation());
132+
}
133+
}
134+
}
135+
}
136+
return nullptr;
137+
}
138+
139+
static transform::TransformOpInterface
140+
findTransformEntryPointRecursive(Operation *op, StringRef entryPoint) {
141+
transform::TransformOpInterface transform = nullptr;
142+
op->walk<WalkOrder::PreOrder>(
143+
[&](transform::NamedSequenceOp namedSequenceOp) {
144+
if (namedSequenceOp.getSymName() == entryPoint) {
145+
transform = cast<transform::TransformOpInterface>(
146+
namedSequenceOp.getOperation());
147+
return WalkResult::interrupt();
148+
}
149+
return WalkResult::advance();
150+
});
151+
return transform;
152+
}
153+
154+
// Will look for the transform's entry point favouring NamedSequenceOps
155+
// ops that exist within the operation without the need for nesting.
156+
// If no operation exists in the blocks owned by op, then it will recursively
157+
// walk the op in preorder and find the first NamedSequenceOp that matches
158+
// the entry point's name.
159+
//
160+
// This allows for the following two use cases:
161+
// 1. op is a module annotated with the transform.with_named_sequence attribute
162+
// that has an entry point in its block. E.g.,
163+
//
164+
// ```mlir
165+
// module {transform.with_named_sequence} {
166+
// transform.named_sequence @__transform_main(%arg0 : !transform.any_op) ->
167+
// () {
168+
// transform.yield
169+
// }
170+
// }
171+
// ```
172+
//
173+
// 2. op is a program which contains a nested module annotated with the
174+
// transform.with_named_sequence attribute. E.g.,
175+
//
176+
// ```mlir
177+
// module {
178+
// func.func @foo () {
179+
// }
180+
//
181+
// module {transform.with_named_sequence} {
182+
// transform.named_sequence @__transform_main(%arg0 : !transform.any_op)
183+
// -> () {
184+
// transform.yield
185+
// }
186+
// }
187+
// }
188+
// ```
189+
static transform::TransformOpInterface
190+
findTransformEntryPointInOp(Operation *op, StringRef entryPoint) {
191+
transform::TransformOpInterface transform =
192+
findTransformEntryPointNonRecursive(op, entryPoint);
193+
if (!transform)
194+
transform = findTransformEntryPointRecursive(op, entryPoint);
195+
return transform;
196+
}
197+
124198
transform::TransformOpInterface
125199
transform::detail::findTransformEntryPoint(Operation *root, ModuleOp module,
126200
StringRef entryPoint) {
127201
SmallVector<Operation *, 2> l{root};
128202
if (module)
129203
l.push_back(module);
130204
for (Operation *op : l) {
131-
transform::TransformOpInterface transform = nullptr;
132-
op->walk<WalkOrder::PreOrder>(
133-
[&](transform::NamedSequenceOp namedSequenceOp) {
134-
if (namedSequenceOp.getSymName() == entryPoint) {
135-
transform = cast<transform::TransformOpInterface>(
136-
namedSequenceOp.getOperation());
137-
return WalkResult::interrupt();
138-
}
139-
return WalkResult::advance();
140-
});
205+
TransformOpInterface transform =
206+
findTransformEntryPointInOp(op, entryPoint);
141207
if (transform)
142208
return transform;
143209
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
2+
3+
module @td_module_4 attributes {transform.with_named_sequence} {
4+
module @foo_module attributes {transform.with_named_sequence} {
5+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) -> () {
6+
// CHECK: IR printer: foo_module top-level
7+
transform.print {name="foo_module"}
8+
transform.yield
9+
}
10+
}
11+
module @bar_module attributes {transform.with_named_sequence} {
12+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) -> () {
13+
// CHECK: IR printer: bar_module top-level
14+
transform.print {name="bar_module"}
15+
transform.yield
16+
}
17+
}
18+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) -> () {
19+
transform.include @foo_module::@__transform_main failures(suppress) (%arg0) : (!transform.any_op) -> ()
20+
transform.include @bar_module::@__transform_main failures(suppress) (%arg0) : (!transform.any_op) -> ()
21+
transform.yield
22+
}
23+
}

0 commit comments

Comments
 (0)