From 9af755cd4794434439074e2e80e5ac0a43aad78a Mon Sep 17 00:00:00 2001 From: max Date: Fri, 17 May 2024 22:36:41 -0500 Subject: [PATCH 01/12] [mlir][python] wip remove liveOpeartions --- mlir/include/mlir-c/IR.h | 2 + mlir/lib/Bindings/Python/IRCore.cpp | 174 +++--------------- mlir/lib/Bindings/Python/IRModule.h | 53 ------ mlir/lib/Bindings/Python/Pass.cpp | 8 +- .../Bindings/Python/TransformInterpreter.cpp | 1 - mlir/lib/CAPI/IR/IR.cpp | 4 + mlir/test/python/ir/module.py | 22 +-- mlir/test/python/ir/operation.py | 3 +- mlir/test/python/ir/symbol_table.py | 8 - mlir/test/python/pass_manager.py | 27 +-- 10 files changed, 42 insertions(+), 260 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 71c7d4378677f..d05f91d7e3b12 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -415,6 +415,8 @@ MLIR_CAPI_EXPORTED MlirOperation mlirModuleGetOperation(MlirModule module); /// The returned module is null when the input operation was not a ModuleOp. MLIR_CAPI_EXPORTED MlirModule mlirModuleFromOperation(MlirOperation op); +MLIR_CAPI_EXPORTED bool mlirModuleEqual(MlirModule lhs, MlirModule rhs); + //===----------------------------------------------------------------------===// // Operation state. 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 15889ddabd2c4..789891f495217 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -702,84 +702,6 @@ size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } -size_t PyMlirContext::getLiveOperationCount() { - nb::ft_lock_guard lock(liveOperationsMutex); - return liveOperations.size(); -} - -std::vector PyMlirContext::getLiveOperationObjects() { - std::vector liveObjects; - nb::ft_lock_guard lock(liveOperationsMutex); - for (auto &entry : liveOperations) - liveObjects.push_back(entry.second.second); - return liveObjects; -} - -size_t PyMlirContext::clearLiveOperations() { - - LiveOperationMap operations; - { - nb::ft_lock_guard lock(liveOperationsMutex); - std::swap(operations, liveOperations); - } - for (auto &op : operations) - op.second.second->setInvalid(); - size_t numInvalidated = operations.size(); - return numInvalidated; -} - -void PyMlirContext::clearOperation(MlirOperation op) { - PyOperation *pyOp; - { - nb::ft_lock_guard lock(liveOperationsMutex); - auto it = liveOperations.find(op.ptr); - if (it == liveOperations.end()) { - return; - } - pyOp = it->second.second; - liveOperations.erase(it); - } - pyOp->setInvalid(); -} - -void PyMlirContext::clearOperationsInside(PyOperationBase &op) { - using callBackData = struct { - PyOperation &rootOp; - bool rootSeen; - }; - callBackData data{op.getOperation(), false}; - // Mark all ops below the op that the passmanager will be rooted - // at (but not op itself - note the preorder) as invalid. 
- MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op, - void *userData) { - callBackData *data = static_cast(userData); - if (LLVM_LIKELY(data->rootSeen)) - data->rootOp.getOperation().getContext()->clearOperation(op); - else - data->rootSeen = true; - return MlirWalkResult::MlirWalkResultAdvance; - }; - mlirOperationWalk(op.getOperation(), invalidatingCallback, - static_cast(&data), MlirWalkPreOrder); -} -void PyMlirContext::clearOperationsInside(MlirOperation op) { - PyOperationRef opRef = PyOperation::forOperation(getRef(), op); - clearOperationsInside(opRef->getOperation()); -} - -void PyMlirContext::clearOperationAndInside(PyOperationBase &op) { - MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op, - void *userData) { - PyMlirContextRef &contextRef = *static_cast(userData); - contextRef->clearOperation(op); - return MlirWalkResult::MlirWalkResultAdvance; - }; - mlirOperationWalk(op.getOperation(), invalidatingCallback, - &op.getOperation().getContext(), MlirWalkPreOrder); -} - -size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } - nb::object PyMlirContext::contextEnter(nb::object context) { return PyThreadContextEntry::pushContext(context); } @@ -1151,38 +1073,20 @@ PyLocation &DefaultingPyLocation::resolve() { PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) : BaseContextObject(std::move(contextRef)), module(module) {} -PyModule::~PyModule() { - nb::gil_scoped_acquire acquire; - auto &liveModules = getContext()->liveModules; - assert(liveModules.count(module.ptr) == 1 && - "destroying module not in live map"); - liveModules.erase(module.ptr); - mlirModuleDestroy(module); -} +PyModule::~PyModule() { mlirModuleDestroy(module); } PyModuleRef PyModule::forModule(MlirModule module) { MlirContext context = mlirModuleGetContext(module); PyMlirContextRef contextRef = PyMlirContext::forContext(context); - nb::gil_scoped_acquire acquire; - auto &liveModules = contextRef->liveModules; - auto it 
= liveModules.find(module.ptr); - if (it == liveModules.end()) { - // Create. - PyModule *unownedModule = new PyModule(std::move(contextRef), module); - // Note that the default return value policy on cast is automatic_reference, - // which does not take ownership (delete will not be called). - // Just be explicit. - nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership); - unownedModule->handle = pyRef; - liveModules[module.ptr] = - std::make_pair(unownedModule->handle, unownedModule); - return PyModuleRef(unownedModule, std::move(pyRef)); - } - // Use existing. - PyModule *existing = it->second.second; - nb::object pyRef = nb::borrow(it->second.first); - return PyModuleRef(existing, std::move(pyRef)); + // Create. + PyModule *unownedModule = new PyModule(std::move(contextRef), module); + // Note that the default return value policy on cast is automatic_reference, + // which does not take ownership (delete will not be called). + // Just be explicit. + nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership); + unownedModule->handle = pyRef; + return PyModuleRef(unownedModule, std::move(pyRef)); } nb::object PyModule::createFromCapsule(nb::object capsule) { @@ -1207,15 +1111,11 @@ PyOperation::~PyOperation() { // If the operation has already been invalidated there is nothing to do. if (!valid) return; - - // Otherwise, invalidate the operation and remove it from live map when it is - // attached. - if (isAttached()) { - getContext()->clearOperation(*this); - } else { - // And destroy it when it is detached, i.e. owned by Python, in which case - // all nested operations must be invalidated at removed from the live map as - // well. + // Otherwise, invalidate the operation when it is attached. + if (isAttached()) + setInvalid(); + else { + // And destroy it when it is detached, i.e. owned by Python. 
erase(); } } @@ -1252,35 +1152,16 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, MlirOperation operation, nb::object parentKeepAlive) { - nb::ft_lock_guard lock(contextRef->liveOperationsMutex); - auto &liveOperations = contextRef->liveOperations; - auto it = liveOperations.find(operation.ptr); - if (it == liveOperations.end()) { - // Create. - PyOperationRef result = createInstance(std::move(contextRef), operation, - std::move(parentKeepAlive)); - liveOperations[operation.ptr] = - std::make_pair(result.getObject(), result.get()); - return result; - } - // Use existing. - PyOperation *existing = it->second.second; - nb::object pyRef = nb::borrow(it->second.first); - return PyOperationRef(existing, std::move(pyRef)); + // Create. + return createInstance(std::move(contextRef), operation, + std::move(parentKeepAlive)); } PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, MlirOperation operation, nb::object parentKeepAlive) { - nb::ft_lock_guard lock(contextRef->liveOperationsMutex); - auto &liveOperations = contextRef->liveOperations; - assert(liveOperations.count(operation.ptr) == 0 && - "cannot create detached operation that already exists"); - (void)liveOperations; PyOperationRef created = createInstance(std::move(contextRef), operation, std::move(parentKeepAlive)); - liveOperations[operation.ptr] = - std::make_pair(created.getObject(), created.get()); created->attached = false; return created; } @@ -1652,7 +1533,7 @@ nb::object PyOperation::createOpView() { void PyOperation::erase() { checkValid(); - getContext()->clearOperationAndInside(*this); + setInvalid(); mlirOperationDestroy(operation); } @@ -3023,14 +2904,6 @@ void mlir::python::populateIRCore(nb::module_ &m) { PyMlirContextRef ref = PyMlirContext::forContext(self.get()); return ref.releaseObject(); }) - .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) - 
.def("_get_live_operation_objects", - &PyMlirContext::getLiveOperationObjects) - .def("_clear_live_operations", &PyMlirContext::clearLiveOperations) - .def("_clear_live_operations_inside", - nb::overload_cast( - &PyMlirContext::clearOperationsInside)) - .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) .def("__enter__", &PyMlirContext::contextEnter) @@ -3428,7 +3301,13 @@ void mlir::python::populateIRCore(nb::module_ &m) { // Defer to the operation's __str__. return self.attr("operation").attr("__str__")(); }, - kOperationStrDunderDocstring); + kOperationStrDunderDocstring) + .def( + "__eq__", + [](PyModule &self, PyModule &other) { + return mlirModuleEqual(self.get(), other.get()); + }, + "other"_a); //---------------------------------------------------------------------------- // Mapping of Operation. @@ -3440,7 +3319,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { }) .def("__eq__", [](PyOperationBase &self, PyOperationBase &other) { - return &self.getOperation() == &other.getOperation(); + return mlirOperationEqual(self.getOperation().get(), + other.getOperation().get()); }) .def("__eq__", [](PyOperationBase &self, nb::object other) { return false; }) diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 6617b41cc916c..553da2ef52880 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -218,40 +218,6 @@ class PyMlirContext { /// Gets the count of live context objects. Used for testing. static size_t getLiveCount(); - /// Get a list of Python objects which are still in the live context map. - std::vector getLiveOperationObjects(); - - /// Gets the count of live operations associated with this context. - /// Used for testing. 
- size_t getLiveOperationCount(); - - /// Clears the live operations map, returning the number of entries which were - /// invalidated. To be used as a safety mechanism so that API end-users can't - /// corrupt by holding references they shouldn't have accessed in the first - /// place. - size_t clearLiveOperations(); - - /// Removes an operation from the live operations map and sets it invalid. - /// This is useful for when some non-bindings code destroys the operation and - /// the bindings need to made aware. For example, in the case when pass - /// manager is run. - /// - /// Note that this does *NOT* clear the nested operations. - void clearOperation(MlirOperation op); - - /// Clears all operations nested inside the given op using - /// `clearOperation(MlirOperation)`. - void clearOperationsInside(PyOperationBase &op); - void clearOperationsInside(MlirOperation op); - - /// Clears the operaiton _and_ all operations inside using - /// `clearOperation(MlirOperation)`. - void clearOperationAndInside(PyOperationBase &op); - - /// Gets the count of live modules associated with this context. - /// Used for testing. - size_t getLiveModuleCount(); - /// Enter and exit the context manager. static nanobind::object contextEnter(nanobind::object context); void contextExit(const nanobind::object &excType, @@ -278,25 +244,6 @@ class PyMlirContext { static nanobind::ft_mutex live_contexts_mutex; static LiveContextMap &getLiveContexts(); - // Interns all live modules associated with this context. Modules tracked - // in this map are valid. When a module is invalidated, it is removed - // from this map, and while it still exists as an instance, any - // attempt to access it will raise an error. - using LiveModuleMap = - llvm::DenseMap>; - LiveModuleMap liveModules; - - // Interns all live operations associated with this context. Operations - // tracked in this map are valid. 
When an operation is invalidated, it is - // removed from this map, and while it still exists as an instance, any - // attempt to access it will raise an error. - using LiveOperationMap = - llvm::DenseMap>; - nanobind::ft_mutex liveOperationsMutex; - - // Guarded by liveOperationsMutex in free-threading mode. - LiveOperationMap liveOperations; - bool emitErrorDiagnostics = false; MlirContext context; diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index 20017e25b69bb..817479ee2421b 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -159,11 +159,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { "ValueError if the pipeline can't be parsed.") .def( "run", - [](PyPassManager &passManager, PyOperationBase &op, - bool invalidateOps) { - if (invalidateOps) { - op.getOperation().getContext()->clearOperationsInside(op); - } + [](PyPassManager &passManager, PyOperationBase &op) { // Actually run the pass manager. PyMlirContext::ErrorCapture errors(op.getOperation().getContext()); MlirLogicalResult status = mlirPassManagerRunOnOp( @@ -172,7 +168,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { throw MLIRError("Failure while executing pass pipeline", errors.take()); }, - "operation"_a, "invalidate_ops"_a = true, + "operation"_a, "Run the pass manager on the provided operation, raising an " "MLIRError on failure.") .def( diff --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp index f9b0fed62778f..920bca886f617 100644 --- a/mlir/lib/Bindings/Python/TransformInterpreter.cpp +++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp @@ -67,7 +67,6 @@ static void populateTransformInterpreterSubmodule(nb::module_ &m) { // root. This is awkward, but we don't have access to PyMlirContext // object here otherwise. 
nb::object obj = nb::cast(payloadRoot); - obj.attr("context").attr("_clear_live_operations_inside")(payloadRoot); MlirLogicalResult result = mlirTransformApplyNamedSequence( payloadRoot, transformRoot, transformModule, options.options); diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 8491553dab76f..c7069f0017b5d 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -465,6 +465,10 @@ MlirModule mlirModuleFromOperation(MlirOperation op) { return wrap(dyn_cast(unwrap(op))); } +bool mlirModuleEqual(MlirModule lhs, MlirModule rhs) { + return unwrap(lhs) == unwrap(rhs); +} + //===----------------------------------------------------------------------===// // Operation state API. //===----------------------------------------------------------------------===// diff --git a/mlir/test/python/ir/module.py b/mlir/test/python/ir/module.py index 6065e59fd6ed9..449e25d4edde2 100644 --- a/mlir/test/python/ir/module.py +++ b/mlir/test/python/ir/module.py @@ -121,27 +121,17 @@ def testRoundtripBinary(): def testModuleOperation(): ctx = Context() module = Module.parse(r"""module @successfulParse {}""", ctx) - assert ctx._get_live_module_count() == 1 op1 = module.operation - assert ctx._get_live_operation_count() == 1 - live_ops = ctx._get_live_operation_objects() - assert len(live_ops) == 1 - assert live_ops[0] is op1 - live_ops = None # CHECK: module @successfulParse print(op1) # Ensure that operations are the same on multiple calls. op2 = module.operation - assert ctx._get_live_operation_count() == 1 - assert op1 is op2 + assert not op1 is op2 + assert op1 == op2 # Test live operation clearing. 
op1 = module.operation - assert ctx._get_live_operation_count() == 1 - num_invalidated = ctx._clear_live_operations() - assert num_invalidated == 1 - assert ctx._get_live_operation_count() == 0 op1 = None gc.collect() op1 = module.operation @@ -155,9 +145,6 @@ def testModuleOperation(): op1 = None op2 = None gc.collect() - print("LIVE OPERATIONS:", ctx._get_live_operation_count()) - assert ctx._get_live_operation_count() == 0 - assert ctx._get_live_module_count() == 0 # CHECK-LABEL: TEST: testModuleCapsule @@ -165,16 +152,15 @@ def testModuleOperation(): def testModuleCapsule(): ctx = Context() module = Module.parse(r"""module @successfulParse {}""", ctx) - assert ctx._get_live_module_count() == 1 # CHECK: "mlir.ir.Module._CAPIPtr" module_capsule = module._CAPIPtr print(module_capsule) module_dup = Module._CAPICreate(module_capsule) - assert module is module_dup + assert not module is module_dup + assert module == module_dup assert module_dup.context is ctx # Gc and verify destructed. module = None module_capsule = None module_dup = None gc.collect() - assert ctx._get_live_module_count() == 0 diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py index bf16e3f75d60d..bb74b6bc5e5ed 100644 --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -907,7 +907,8 @@ def testCapsuleConversions(): m_capsule = m._CAPIPtr assert '"mlir.ir.Operation._CAPIPtr"' in repr(m_capsule) m2 = Operation._CAPICreate(m_capsule) - assert m2 is m + assert not m2 is m + assert m2 == m # CHECK-LABEL: TEST: testOperationErase diff --git a/mlir/test/python/ir/symbol_table.py b/mlir/test/python/ir/symbol_table.py index 8b6d7ea5a197d..7afd539271d21 100644 --- a/mlir/test/python/ir/symbol_table.py +++ b/mlir/test/python/ir/symbol_table.py @@ -56,14 +56,6 @@ def testSymbolTableInsert(): print(m1) assert "bar" not in symbol_table - try: - print(bar) - except RuntimeError as e: - if "the operation has been invalidated" not in str(e): - raise - else: 
- assert False, "expected RuntimeError due to invalidated operation" - qux = m2.body.operations[0] m1.body.append(qux) symbol_table.insert(qux) diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py index e26d42bb32913..aea8803a57bc5 100644 --- a/mlir/test/python/pass_manager.py +++ b/mlir/test/python/pass_manager.py @@ -176,14 +176,6 @@ def testRunPipelineError(): @run def testPostPassOpInvalidation(): with Context() as ctx: - log_op_count = lambda: log("live ops:", ctx._get_live_operation_count()) - - # CHECK: invalidate_ops=False - log("invalidate_ops=False") - - # CHECK: live ops: 0 - log_op_count() - module = ModuleOp.parse( """ module { @@ -196,9 +188,6 @@ def testPostPassOpInvalidation(): """ ) - # CHECK: live ops: 1 - log_op_count() - outer_const_op = module.body.operations[0] # CHECK: %[[VAL0:.*]] = arith.constant 10 : i64 log(outer_const_op) @@ -214,12 +203,7 @@ def testPostPassOpInvalidation(): # CHECK: %[[VAL1]] = arith.constant 10 : i64 log(inner_const_op) - # CHECK: live ops: 4 - log_op_count() - - PassManager.parse("builtin.module(canonicalize)").run( - module, invalidate_ops=False - ) + PassManager.parse("builtin.module(canonicalize)").run(module) # CHECK: func.func @foo() { # CHECK: return # CHECK: } @@ -233,9 +217,6 @@ def testPostPassOpInvalidation(): # CHECK: invalidate_ops=True log("invalidate_ops=True") - # CHECK: live ops: 4 - log_op_count() - module = ModuleOp.parse( """ module { @@ -251,14 +232,8 @@ def testPostPassOpInvalidation(): func_op = module.body.operations[1] inner_const_op = func_op.body.blocks[0].operations[0] - # CHECK: live ops: 4 - log_op_count() - PassManager.parse("builtin.module(canonicalize)").run(module) - # CHECK: live ops: 1 - log_op_count() - try: log(func_op) except RuntimeError as e: From 809f0dc5c667ce39e72843e50e14ed3fc10dad8d Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Sat, 23 Aug 2025 14:30:14 -0400 Subject: [PATCH 02/12] "fix" testPostPassOpInvalidation --- 
mlir/test/python/pass_manager.py | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py index aea8803a57bc5..0896cd9784641 100644 --- a/mlir/test/python/pass_manager.py +++ b/mlir/test/python/pass_manager.py @@ -228,30 +228,9 @@ def testPostPassOpInvalidation(): } """ ) - outer_const_op = module.body.operations[0] - func_op = module.body.operations[1] - inner_const_op = func_op.body.blocks[0].operations[0] PassManager.parse("builtin.module(canonicalize)").run(module) - try: - log(func_op) - except RuntimeError as e: - # CHECK: the operation has been invalidated - log(e) - - try: - log(outer_const_op) - except RuntimeError as e: - # CHECK: the operation has been invalidated - log(e) - - try: - log(inner_const_op) - except RuntimeError as e: - # CHECK: the operation has been invalidated - log(e) - # CHECK: func.func @foo() { # CHECK: return # CHECK: } From fd8a12aa687bc0cb021bd1de0a248e132744c8e3 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Sat, 23 Aug 2025 15:06:07 -0400 Subject: [PATCH 03/12] try to fix testModuleCapsule --- mlir/lib/Bindings/Python/IRCore.cpp | 1 + mlir/lib/Bindings/Python/IRModule.h | 2 ++ mlir/test/python/ir/module.py | 2 ++ 3 files changed, 5 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 789891f495217..b0dc9f17a76b1 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -3222,6 +3222,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::class_(m, "Module", nb::is_weak_referenceable()) .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) + .def("_clear_mlir_module", &PyModule::clearMlirModule) .def_static( "parse", [](const std::string &moduleAsm, DefaultingPyMlirContext context) { diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 
553da2ef52880..932d46b5fd7ba 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -522,6 +522,8 @@ class PyModule : public BaseContextObject { /// is taken by calling this function. static nanobind::object createFromCapsule(nanobind::object capsule); + void clearMlirModule() { module = {nullptr}; } + private: PyModule(PyMlirContextRef contextRef, MlirModule module); MlirModule module; diff --git a/mlir/test/python/ir/module.py b/mlir/test/python/ir/module.py index 449e25d4edde2..a552eaa662af4 100644 --- a/mlir/test/python/ir/module.py +++ b/mlir/test/python/ir/module.py @@ -158,6 +158,8 @@ def testModuleCapsule(): module_dup = Module._CAPICreate(module_capsule) assert not module is module_dup assert module == module_dup + module._clear_mlir_module() + assert not module == module_dup assert module_dup.context is ctx # Gc and verify destructed. module = None From 8eb4d75524f2dd9dc2123a2d9f0bde0295b6d23d Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Sat, 23 Aug 2025 16:43:49 -0400 Subject: [PATCH 04/12] add check for Operation._CAPICreate --- mlir/test/python/ir/operation.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py index bb74b6bc5e5ed..94f39c0fbd077 100644 --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -909,6 +909,11 @@ def testCapsuleConversions(): m2 = Operation._CAPICreate(m_capsule) assert not m2 is m assert m2 == m + # Gc and verify destructed. 
+ m = None + m_capsule = None + m2 = None + gc.collect() # CHECK-LABEL: TEST: testOperationErase From 8c722718b034cff71e7737afc22930a2ec5a1e6a Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Sat, 23 Aug 2025 17:17:36 -0400 Subject: [PATCH 05/12] update docs --- mlir/lib/Bindings/Python/IRCore.cpp | 9 ++++++++- mlir/lib/Bindings/Python/IRModule.h | 9 ++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index b0dc9f17a76b1..7f31ea1a7b1c8 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -67,6 +67,12 @@ Returns a new MlirModule or raises an MLIRError if the parsing fails. See also: https://mlir.llvm.org/docs/LangRef/ )"; +static const char kModuleCAPICreate[] = + R"(Creates a Module from a MlirModule wrapped by a capsule (i.e. module._CAPIPtr). +Note this returns a new object BUT _clear_mlir_module(module) must be called to +prevent double-frees (of the underlying mlir::Module). +)"; + static const char kOperationCreateDocstring[] = R"(Creates a new operation. 
@@ -3221,7 +3227,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { //---------------------------------------------------------------------------- nb::class_(m, "Module", nb::is_weak_referenceable()) .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule, + kModuleCAPICreate) .def("_clear_mlir_module", &PyModule::clearMlirModule) .def_static( "parse", diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 932d46b5fd7ba..0cc0459ebc9a0 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -495,8 +495,8 @@ class PyModule; using PyModuleRef = PyObjectRef; class PyModule : public BaseContextObject { public: - /// Returns a PyModule reference for the given MlirModule. This may return - /// a pre-existing or new object. + /// Returns a PyModule reference for the given MlirModule. This always returns + /// a new object. static PyModuleRef forModule(MlirModule module); PyModule(PyModule &) = delete; PyModule(PyMlirContext &&) = delete; @@ -517,9 +517,8 @@ class PyModule : public BaseContextObject { nanobind::object getCapsule(); /// Creates a PyModule from the MlirModule wrapped by a capsule. - /// Note that PyModule instances are uniqued, so the returned object - /// may be a pre-existing object. Ownership of the underlying MlirModule - /// is taken by calling this function. + /// Note this returns a new object BUT clearMlirModule() must be called to + /// prevent double-frees (of the underlying mlir::Module). 
static nanobind::object createFromCapsule(nanobind::object capsule); void clearMlirModule() { module = {nullptr}; } From 70d6b56d9f26928af5450c9efac62ff2694cf369 Mon Sep 17 00:00:00 2001 From: makslevental Date: Tue, 26 Aug 2025 14:58:22 -0400 Subject: [PATCH 06/12] comments --- mlir/include/mlir-c/IR.h | 1 + mlir/lib/Bindings/Python/IRCore.cpp | 10 ++++++---- mlir/lib/Bindings/Python/MainModule.cpp | 4 ++++ 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index d05f91d7e3b12..e97369778b377 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -415,6 +415,7 @@ MLIR_CAPI_EXPORTED MlirOperation mlirModuleGetOperation(MlirModule module); /// The returned module is null when the input operation was not a ModuleOp. MLIR_CAPI_EXPORTED MlirModule mlirModuleFromOperation(MlirOperation op); +/// Checks if two modules are equal. MLIR_CAPI_EXPORTED bool mlirModuleEqual(MlirModule lhs, MlirModule rhs); //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 7f31ea1a7b1c8..8ab8901cdc41f 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1087,9 +1087,12 @@ PyModuleRef PyModule::forModule(MlirModule module) { // Create. PyModule *unownedModule = new PyModule(std::move(contextRef), module); - // Note that the default return value policy on cast is automatic_reference, - // which does not take ownership (delete will not be called). - // Just be explicit. + // Note that the default return value policy on cast is `automatic_reference`, + // which means "does not take ownership, does not call delete/dtor". + // We use `take_ownership`, which means "Python will call the C++ destructor + // and delete operator when the Python wrapper is garbage collected", because + // MlirModule actually wraps OwningOpRef<ModuleOp> (see mlirModuleCreateParse + // etc). 
nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership); unownedModule->handle = pyRef; return PyModuleRef(unownedModule, std::move(pyRef)); @@ -1158,7 +1161,6 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, MlirOperation operation, nb::object parentKeepAlive) { - // Create. return createInstance(std::move(contextRef), operation, std::move(parentKeepAlive)); } diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 278847e7ac7f5..d091d6a11ab11 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -139,4 +139,8 @@ NB_MODULE(_mlir, m) { auto passModule = m.def_submodule("passmanager", "MLIR Pass Management Bindings"); populatePassManagerSubmodule(passModule); + + m.def("test_raise_exception", []() { + throw std::runtime_error("wtfbbq"); + }); } From 685fbb5d0191557fe495d9c1274e14632a4a43f6 Mon Sep 17 00:00:00 2001 From: makslevental Date: Tue, 26 Aug 2025 15:15:40 -0400 Subject: [PATCH 07/12] update the docs --- mlir/docs/Bindings/Python.md | 32 +++++++++++++++---------- mlir/lib/Bindings/Python/MainModule.cpp | 4 ---- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md index bef9e7f54948d..2c7ef01e8fb4e 100644 --- a/mlir/docs/Bindings/Python.md +++ b/mlir/docs/Bindings/Python.md @@ -216,13 +216,26 @@ added to an attached operation, they need to be re-parented to the containing module). Due to the validity and parenting accounting needs, `PyOperation` is the owner -for regions and blocks and needs to be a top-level type that we can count on not -aliasing. This let's us do things like selectively invalidating instances when -mutations occur without worrying that there is some alias to the same operation -in the hierarchy. 
Operations are also the only entity that are allowed to be in -a detached state, and they are interned at the context level so that there is -never more than one Python `mlir.ir.Operation` object for a unique -`MlirOperation`, regardless of how it is obtained. +for regions and blocks. Operations are also the only entity that are allowed to be in +a detached state. + +**Note**: Multiple `PyOperation` objects (i.e., the Python objects themselves) can alias a single `mlir::Operation`. +This means, for example, if you have `py_op1` and `py_op2` which wrap the same `mlir::Operation op` +and you somehow transform `op` (e.g., you run a pass on `op`) then walking the MLIR AST via either/or `py_op1`, `py_op2` +will reflect the same MLIR AST. This is perfectly safe and supported. What is not supported is invalidating any +operation while there exist multiple Python objects wrapping that operation **and then manipulating those wrappers**. +For example if `py_op1` and `py_op2` wrap the same operation under a root `py_op3` and then `py_op3` is +transformed such that the operation referenced (by `py_op1`, `py_op2`) is erased. Then `py_op1`, `py_op2` +become "undefined" in a sense; manipulating them in any way is "formally forbidden". Note, this also applies to +`SymbolTable` mutation, which is considered a transformation of the root `SymbolTable`-supporting operation for the +purposes of the discussion here. The "best practices" recommendation is to structure your code such that + +1. First, query/manipulate various Python wrapper objects `py_op1`, `py_op2`, `py_op3`, etc.; +2. Second, Transform the AST/erase operations/etc. via a single root object; +3. End. + +Ideally this should be done in a function body so that "End" corresponds to the end of the function and there are no +risks of Python wrapper objects leaking/living longer than necessary. 
The C/C++ API allows for Region/Block to also be detached, but it simplifies the ownership model a lot to eliminate that possibility in this API, allowing the @@ -238,11 +251,6 @@ blocks. We may end up needing an op-local one at some point TBD, depending on how hard it is to guarantee how mutations interact with their Python peer objects. We can cross that bridge easily when we get there. -Module, when used purely from the Python API, can't alias anyway, so we can use -it as a top-level ref type without a live-list for interning. If the API ever -changes such that this cannot be guaranteed (i.e. by letting you marshal a -native-defined Module in), then there would need to be a live table for it too. - ## User-level API ### Context Management diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index d091d6a11ab11..278847e7ac7f5 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -139,8 +139,4 @@ NB_MODULE(_mlir, m) { auto passModule = m.def_submodule("passmanager", "MLIR Pass Management Bindings"); populatePassManagerSubmodule(passModule); - - m.def("test_raise_exception", []() { - throw std::runtime_error("wtfbbq"); - }); } From 57ad5a11c55d49ef58d5d49cdb30ae093e28cf40 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Tue, 26 Aug 2025 23:08:59 -0400 Subject: [PATCH 08/12] Update Python.md --- mlir/docs/Bindings/Python.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md index 2c7ef01e8fb4e..c1a866f57b945 100644 --- a/mlir/docs/Bindings/Python.md +++ b/mlir/docs/Bindings/Python.md @@ -216,7 +216,7 @@ added to an attached operation, they need to be re-parented to the containing module). Due to the validity and parenting accounting needs, `PyOperation` is the owner -for regions and blocks. Operations are also the only entity that are allowed to be in +for regions and blocks. 
Operations are also the only entities which are allowed to be in a detached state. **Note**: Multiple `PyOperation` objects (i.e., the Python objects themselves) can alias a single `mlir::Operation`. @@ -1237,4 +1237,4 @@ The exceptions to the free-threading compatibility: - Usage of `Location.emit_error` is unsafe (due to thread-unsafe `llvm::raw_ostream`). - Usage of `Module.dump` is unsafe (due to thread-unsafe `llvm::raw_ostream`). - Usage of `mlir.dialects.transform.interpreter` is unsafe. -- Usage of `mlir.dialects.gpu` and `gpu-module-to-binary` is unsafe. \ No newline at end of file +- Usage of `mlir.dialects.gpu` and `gpu-module-to-binary` is unsafe. From d319108b651b2bb3f048214691c69a1972ab9e50 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Tue, 26 Aug 2025 23:13:16 -0400 Subject: [PATCH 09/12] Update Python.md --- mlir/docs/Bindings/Python.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md index c1a866f57b945..1a0e57db94c9b 100644 --- a/mlir/docs/Bindings/Python.md +++ b/mlir/docs/Bindings/Python.md @@ -228,7 +228,7 @@ For example if `py_op1` and `py_op2` wrap the same operation under a root `py_op transformed such that the operation referenced (by `py_op1`, `py_op2`) is erased. Then `py_op1`, `py_op2` become "undefined" in a sense; manipulating them in any way is "formally forbidden". Note, this also applies to `SymbolTable` mutation, which is considered a transformation of the root `SymbolTable`-supporting operation for the -purposes of the discussion here. The "best practices" recommendation is to structure your code such that +purposes of the discussion here. Methodically one can think of this similarly to how STL container iterators are invalidated once the container itself is changed. The "best practices" recommendation is to structure your code such that 1. First, query/manipulate various Python wrapper objects `py_op1`, `py_op2`, `py_op3`, etc.; 2. 
Second, Transform the AST/erase operations/etc. via a single root object; From 8ee9d2c56a60a9057edae93e0221d5ed585fcbfd Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Tue, 26 Aug 2025 23:15:35 -0400 Subject: [PATCH 10/12] Update Python.md --- mlir/docs/Bindings/Python.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md index 1a0e57db94c9b..8067c5cc217a7 100644 --- a/mlir/docs/Bindings/Python.md +++ b/mlir/docs/Bindings/Python.md @@ -228,7 +228,7 @@ For example if `py_op1` and `py_op2` wrap the same operation under a root `py_op transformed such that the operation referenced (by `py_op1`, `py_op2`) is erased. Then `py_op1`, `py_op2` become "undefined" in a sense; manipulating them in any way is "formally forbidden". Note, this also applies to `SymbolTable` mutation, which is considered a transformation of the root `SymbolTable`-supporting operation for the -purposes of the discussion here. Methodically one can think of this similarly to how STL container iterators are invalidated once the container itself is changed. The "best practices" recommendation is to structure your code such that +purposes of the discussion here. Metaphorically, one can think of this similarly to how STL container iterators are invalidated once the container itself is changed. The "best practices" recommendation is to structure your code such that 1. First, query/manipulate various Python wrapper objects `py_op1`, `py_op2`, `py_op3`, etc.; 2. Second, Transform the AST/erase operations/etc. 
via a single root object; From 7cc8a50c03f2c139070358fb55db925b8b6ceb9d Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Mon, 1 Sep 2025 22:20:53 -0400 Subject: [PATCH 11/12] re-enable invalidating ops --- mlir/lib/Bindings/Python/IRCore.cpp | 10 ++++++++-- mlir/test/python/ir/symbol_table.py | 9 +++++++++ mlir/test/python/pass_manager.py | 21 +++++++++++++++++++++ 3 files changed, 38 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 8ab8901cdc41f..2df2a73fd88ff 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -3545,7 +3545,9 @@ void mlir::python::populateIRCore(nb::module_ &m) { [](PyOperationBase &self) { return PyOpSuccessors(self.getOperation().getRef()); }, - "Returns the list of Operation successors."); + "Returns the list of Operation successors.") + .def("_set_invalid", &PyOperation::setInvalid, + "Invalidate the operation."); auto opViewClass = nb::class_<PyOpView, PyOperationBase>(m, "OpView") @@ -3589,7 +3591,11 @@ void mlir::python::populateIRCore(nb::module_ &m) { [](PyOperationBase &self) { return PyOpSuccessors(self.getOperation().getRef()); }, - "Returns the list of Operation successors."); + "Returns the list of Operation successors.") + .def( + "_set_invalid", + [](PyOpView &self) { self.getOperation().setInvalid(); }, + "Invalidate the operation."); opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true); opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none(); opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none(); diff --git a/mlir/test/python/ir/symbol_table.py b/mlir/test/python/ir/symbol_table.py index 7afd539271d21..99d5fadfea10a 100644 --- a/mlir/test/python/ir/symbol_table.py +++ b/mlir/test/python/ir/symbol_table.py @@ -56,6 +56,15 @@ def testSymbolTableInsert(): print(m1) assert "bar" not in symbol_table + bar._set_invalid() + try: + print(bar) + except RuntimeError as e: + if "the operation has been invalidated" not in str(e): + raise + 
else: + assert False, "expected RuntimeError due to invalidated operation" + qux = m2.body.operations[0] m1.body.append(qux) symbol_table.insert(qux) diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py index 0896cd9784641..5f92f5b52a09a 100644 --- a/mlir/test/python/pass_manager.py +++ b/mlir/test/python/pass_manager.py @@ -231,6 +231,27 @@ def testPostPassOpInvalidation(): PassManager.parse("builtin.module(canonicalize)").run(module) + func_op._set_invalid() + try: + log(func_op) + except RuntimeError as e: + # CHECK: the operation has been invalidated + log(e) + + outer_const_op._set_invalid() + try: + log(outer_const_op) + except RuntimeError as e: + # CHECK: the operation has been invalidated + log(e) + + inner_const_op._set_invalid() + try: + log(inner_const_op) + except RuntimeError as e: + # CHECK: the operation has been invalidated + log(e) + # CHECK: func.func @foo() { # CHECK: return # CHECK: } From e5e345b208cfcd54d2fbfee96fc01a74ddbbbba8 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Mon, 1 Sep 2025 22:28:59 -0400 Subject: [PATCH 12/12] comments --- mlir/docs/Bindings/Python.md | 10 ++++++---- mlir/test/python/ir/module.py | 6 +++--- mlir/test/python/ir/operation.py | 2 +- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md index 8067c5cc217a7..98ac635aa4ee2 100644 --- a/mlir/docs/Bindings/Python.md +++ b/mlir/docs/Bindings/Python.md @@ -231,11 +231,13 @@ become "undefined" in a sense; manipulating them in any way is "formally forbidd purposes of the discussion here. Metaphorically, one can think of this similarly to how STL container iterators are invalidated once the container itself is changed. The "best practices" recommendation is to structure your code such that 1. First, query/manipulate various Python wrapper objects `py_op1`, `py_op2`, `py_op3`, etc.; -2. Second, Transform the AST/erase operations/etc. via a single root object; -3. End. +2. 
Second, transform the AST/erase operations/etc. via a single root object; +3. Invalidate all queried nodes (e.g., using `op._set_invalid()`). -Ideally this should be done in a function body so that "End" corresponds to the end of the function and there are no -risks of Python wrapper objects leaking/living longer than necessary. +Ideally this should be done in a function body so that step (3) corresponds to the end of the function and there are no +risks of Python wrapper objects leaking/living longer than necessary. In summary, you should scope your changes based on +nesting i.e., change leaf nodes first before going up in hierarchy, and only in very rare cases query nested ops post +modifying a parent op. The C/C++ API allows for Region/Block to also be detached, but it simplifies the ownership model a lot to eliminate that possibility in this API, allowing the diff --git a/mlir/test/python/ir/module.py b/mlir/test/python/ir/module.py index a552eaa662af4..ad4c9340a6c82 100644 --- a/mlir/test/python/ir/module.py +++ b/mlir/test/python/ir/module.py @@ -127,7 +127,7 @@ def testModuleOperation(): # Ensure that operations are the same on multiple calls. op2 = module.operation - assert not op1 is op2 + assert op1 is not op2 assert op1 == op2 # Test live operation clearing. @@ -156,10 +156,10 @@ def testModuleCapsule(): module_capsule = module._CAPIPtr print(module_capsule) module_dup = Module._CAPICreate(module_capsule) - assert not module is module_dup + assert module is not module_dup assert module == module_dup module._clear_mlir_module() - assert not module == module_dup + assert module != module_dup assert module_dup.context is ctx # Gc and verify destructed. 
module = None diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py index 94f39c0fbd077..7759b1797e3c3 100644 --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -907,7 +907,7 @@ def testCapsuleConversions(): m_capsule = m._CAPIPtr assert '"mlir.ir.Operation._CAPIPtr"' in repr(m_capsule) m2 = Operation._CAPICreate(m_capsule) - assert not m2 is m + assert m2 is not m assert m2 == m # Gc and verify destructed. m = None