Skip to content

Commit a365084

Browse files
authored
[mlir][python,CAPI] expose Op::isBeforeInBlock (#150271)
1 parent e89e678 commit a365084

File tree

5 files changed

+37
-0
lines changed

5 files changed

+37
-0
lines changed

mlir/include/mlir-c/IR.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -813,6 +813,13 @@ MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op,
813813
MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op,
814814
MlirOperation other);
815815

816+
/// Given an operation 'other' that is within the same parent block, return
817+
/// whether the current operation is before 'other' in the operation list
818+
/// of the parent block.
819+
/// Note: This function has an average complexity of O(1), but worst case may
820+
/// take O(N) where N is the number of operations within the parent block.
821+
MLIR_CAPI_EXPORTED bool mlirOperationIsBeforeInBlock(MlirOperation op,
822+
MlirOperation other);
816823
/// Operation walk result.
817824
typedef enum MlirWalkResult {
818825
MlirWalkResultAdvance,

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1454,6 +1454,14 @@ void PyOperationBase::moveBefore(PyOperationBase &other) {
14541454
operation.parentKeepAlive = otherOp.parentKeepAlive;
14551455
}
14561456

1457+
bool PyOperationBase::isBeforeInBlock(PyOperationBase &other) {
1458+
PyOperation &operation = getOperation();
1459+
PyOperation &otherOp = other.getOperation();
1460+
operation.checkValid();
1461+
otherOp.checkValid();
1462+
return mlirOperationIsBeforeInBlock(operation, otherOp);
1463+
}
1464+
14571465
bool PyOperationBase::verify() {
14581466
PyOperation &op = getOperation();
14591467
PyMlirContext::ErrorCapture errors(op.getContext());
@@ -3409,6 +3417,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
34093417
.def("move_before", &PyOperationBase::moveBefore, nb::arg("other"),
34103418
"Puts self immediately before the other operation in its parent "
34113419
"block.")
3420+
.def("is_before_in_block", &PyOperationBase::isBeforeInBlock,
3421+
nb::arg("other"),
3422+
"Given an operation 'other' that is within the same parent block, "
3423+
"return"
3424+
"whether the current operation is before 'other' in the operation "
3425+
"list"
3426+
"of the parent block.")
34123427
.def(
34133428
"clone",
34143429
[](PyOperationBase &self, nb::object ip) {

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -624,6 +624,13 @@ class PyOperationBase {
624624
void moveAfter(PyOperationBase &other);
625625
void moveBefore(PyOperationBase &other);
626626

627+
/// Given an operation 'other' that is within the same parent block, return
628+
/// whether the current operation is before 'other' in the operation list
629+
/// of the parent block.
630+
/// Note: This function has an average complexity of O(1), but worst case may
631+
/// take O(N) where N is the number of operations within the parent block.
632+
bool isBeforeInBlock(PyOperationBase &other);
633+
627634
/// Verify the operation. Throws `MLIRError` if verification fails, and
628635
/// returns `true` otherwise.
629636
bool verify();

mlir/lib/CAPI/IR/IR.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -850,6 +850,10 @@ void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) {
850850
return unwrap(op)->moveBefore(unwrap(other));
851851
}
852852

853+
bool mlirOperationIsBeforeInBlock(MlirOperation op, MlirOperation other) {
854+
return unwrap(op)->isBeforeInBlock(unwrap(other));
855+
}
856+
853857
static mlir::WalkResult unwrap(MlirWalkResult result) {
854858
switch (result) {
855859
case MlirWalkResultAdvance:

mlir/test/python/ir/operation.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -978,8 +978,12 @@ def testModuleMerge():
978978
foo = m1.body.operations[0]
979979
bar = m2.body.operations[0]
980980
qux = m2.body.operations[1]
981+
assert bar.is_before_in_block(qux)
981982
bar.move_before(foo)
983+
assert bar.is_before_in_block(foo)
982984
qux.move_after(foo)
985+
assert bar.is_before_in_block(qux)
986+
assert foo.is_before_in_block(qux)
983987

984988
# CHECK: module
985989
# CHECK: func private @bar

0 commit comments

Comments
 (0)