Skip to content

Commit 9a0ed70

Browse files
authored
[MLIR][Python] bind InsertionPointAfter (#157156)
1 parent 46d8fdd commit 9a0ed70

File tree

3 files changed

+65
-5
lines changed

3 files changed

+65
-5
lines changed

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2019,7 +2019,7 @@ PyOpView::PyOpView(const nb::object &operationObject)
20192019
// PyInsertionPoint.
20202020
//------------------------------------------------------------------------------
20212021

2022-
PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
2022+
PyInsertionPoint::PyInsertionPoint(const PyBlock &block) : block(block) {}
20232023

20242024
PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
20252025
: refOperation(beforeOperationBase.getOperation().getRef()),
@@ -2073,6 +2073,17 @@ PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
20732073
return PyInsertionPoint{block, std::move(terminatorOpRef)};
20742074
}
20752075

2076+
PyInsertionPoint PyInsertionPoint::after(PyOperationBase &op) {
2077+
PyOperation &operation = op.getOperation();
2078+
PyBlock block = operation.getBlock();
2079+
MlirOperation nextOperation = mlirOperationGetNextInBlock(operation);
2080+
if (mlirOperationIsNull(nextOperation))
2081+
return PyInsertionPoint(block);
2082+
PyOperationRef nextOpRef = PyOperation::forOperation(
2083+
block.getParentOperation()->getContext(), nextOperation);
2084+
return PyInsertionPoint{block, std::move(nextOpRef)};
2085+
}
2086+
20762087
nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) {
20772088
return PyThreadContextEntry::pushInsertionPoint(insertPoint);
20782089
}
@@ -3861,6 +3872,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
38613872
nb::arg("block"), "Inserts at the beginning of the block.")
38623873
.def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
38633874
nb::arg("block"), "Inserts before the block terminator.")
3875+
.def_static("after", &PyInsertionPoint::after, nb::arg("operation"),
3876+
"Inserts after the operation.")
38643877
.def("insert", &PyInsertionPoint::insert, nb::arg("operation"),
38653878
"Inserts an operation.")
38663879
.def_prop_ro(

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -821,14 +821,17 @@ class PyInsertionPoint {
821821
public:
822822
/// Creates an insertion point positioned after the last operation in the
823823
/// block, but still inside the block.
824-
PyInsertionPoint(PyBlock &block);
824+
PyInsertionPoint(const PyBlock &block);
825825
/// Creates an insertion point positioned before a reference operation.
826826
PyInsertionPoint(PyOperationBase &beforeOperationBase);
827827

828828
/// Shortcut to create an insertion point at the beginning of the block.
829829
static PyInsertionPoint atBlockBegin(PyBlock &block);
830830
/// Shortcut to create an insertion point before the block terminator.
831831
static PyInsertionPoint atBlockTerminator(PyBlock &block);
832+
/// Shortcut to create an insertion point to the node after the specified
833+
/// operation.
834+
static PyInsertionPoint after(PyOperationBase &op);
832835

833836
/// Inserts an operation.
834837
void insert(PyOperationBase &operationBase);

mlir/test/python/ir/insertion_point.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,34 @@ def test_insert_before_operation():
6363
run(test_insert_before_operation)
6464

6565

66+
# CHECK-LABEL: TEST: test_insert_after_operation
67+
def test_insert_after_operation():
68+
ctx = Context()
69+
ctx.allow_unregistered_dialects = True
70+
with Location.unknown(ctx):
71+
module = Module.parse(
72+
r"""
73+
func.func @foo() -> () {
74+
"custom.op1"() : () -> ()
75+
"custom.op2"() : () -> ()
76+
}
77+
"""
78+
)
79+
entry_block = module.body.operations[0].regions[0].blocks[0]
80+
custom_op1 = entry_block.operations[0]
81+
custom_op2 = entry_block.operations[1]
82+
InsertionPoint.after(custom_op1).insert(Operation.create("custom.op3"))
83+
InsertionPoint.after(custom_op2).insert(Operation.create("custom.op4"))
84+
# CHECK: "custom.op1"
85+
# CHECK: "custom.op3"
86+
# CHECK: "custom.op2"
87+
# CHECK: "custom.op4"
88+
module.operation.print()
89+
90+
91+
run(test_insert_after_operation)
92+
93+
6694
# CHECK-LABEL: TEST: test_insert_at_block_begin
6795
def test_insert_at_block_begin():
6896
ctx = Context()
@@ -111,14 +139,24 @@ def test_insert_at_terminator():
111139
"""
112140
)
113141
entry_block = module.body.operations[0].regions[0].blocks[0]
142+
return_op = entry_block.operations[1]
114143
ip = InsertionPoint.at_block_terminator(entry_block)
115144
assert ip.block == entry_block
116-
assert ip.ref_operation == entry_block.operations[1]
117-
ip.insert(Operation.create("custom.op2"))
145+
assert ip.ref_operation == return_op
146+
custom_op2 = Operation.create("custom.op2")
147+
ip.insert(custom_op2)
148+
InsertionPoint.after(custom_op2).insert(Operation.create("custom.op3"))
118149
# CHECK: "custom.op1"
119150
# CHECK: "custom.op2"
151+
# CHECK: "custom.op3"
120152
module.operation.print()
121153

154+
try:
155+
InsertionPoint.after(return_op).insert(Operation.create("custom.op4"))
156+
except IndexError as e:
157+
# CHECK: ERROR: Cannot insert operation at the end of a block that already has a terminator.
158+
print(f"ERROR: {e}")
159+
122160

123161
run(test_insert_at_terminator)
124162

@@ -187,10 +225,16 @@ def test_insertion_point_context():
187225
with InsertionPoint(entry_block):
188226
Operation.create("custom.op2")
189227
with InsertionPoint.at_block_begin(entry_block):
190-
Operation.create("custom.opa")
228+
custom_opa = Operation.create("custom.opa")
191229
Operation.create("custom.opb")
192230
Operation.create("custom.op3")
231+
with InsertionPoint.after(custom_opa):
232+
Operation.create("custom.op4")
233+
Operation.create("custom.op5")
234+
193235
# CHECK: "custom.opa"
236+
# CHECK: "custom.op4"
237+
# CHECK: "custom.op5"
194238
# CHECK: "custom.opb"
195239
# CHECK: "custom.op1"
196240
# CHECK: "custom.op2"

0 commit comments

Comments
 (0)