@@ -67,6 +67,12 @@ Returns a new MlirModule or raises an MLIRError if the parsing fails.
6767See also: https://mlir.llvm.org/docs/LangRef/
6868)" ;
6969
70+ static const char kModuleCAPICreate [] =
71+ R"( Creates a Module from a MlirModule wrapped by a capsule (i.e. module._CAPIPtr).
72+ Note this returns a new object BUT _clear_mlir_module(module) must be called to
73+ prevent double-frees (of the underlying mlir::Module).
74+ )" ;
75+
7076static const char kOperationCreateDocstring [] =
7177 R"( Creates a new operation.
7278
@@ -702,84 +708,6 @@ size_t PyMlirContext::getLiveCount() {
702708 return getLiveContexts ().size ();
703709}
704710
705- size_t PyMlirContext::getLiveOperationCount () {
706- nb::ft_lock_guard lock (liveOperationsMutex);
707- return liveOperations.size ();
708- }
709-
710- std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects () {
711- std::vector<PyOperation *> liveObjects;
712- nb::ft_lock_guard lock (liveOperationsMutex);
713- for (auto &entry : liveOperations)
714- liveObjects.push_back (entry.second .second );
715- return liveObjects;
716- }
717-
718- size_t PyMlirContext::clearLiveOperations () {
719-
720- LiveOperationMap operations;
721- {
722- nb::ft_lock_guard lock (liveOperationsMutex);
723- std::swap (operations, liveOperations);
724- }
725- for (auto &op : operations)
726- op.second .second ->setInvalid ();
727- size_t numInvalidated = operations.size ();
728- return numInvalidated;
729- }
730-
731- void PyMlirContext::clearOperation (MlirOperation op) {
732- PyOperation *pyOp;
733- {
734- nb::ft_lock_guard lock (liveOperationsMutex);
735- auto it = liveOperations.find (op.ptr );
736- if (it == liveOperations.end ()) {
737- return ;
738- }
739- pyOp = it->second .second ;
740- liveOperations.erase (it);
741- }
742- pyOp->setInvalid ();
743- }
744-
745- void PyMlirContext::clearOperationsInside (PyOperationBase &op) {
746- using callBackData = struct {
747- PyOperation &rootOp;
748- bool rootSeen;
749- };
750- callBackData data{op.getOperation (), false };
751- // Mark all ops below the op that the passmanager will be rooted
752- // at (but not op itself - note the preorder) as invalid.
753- MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
754- void *userData) {
755- callBackData *data = static_cast <callBackData *>(userData);
756- if (LLVM_LIKELY (data->rootSeen ))
757- data->rootOp .getOperation ().getContext ()->clearOperation (op);
758- else
759- data->rootSeen = true ;
760- return MlirWalkResult::MlirWalkResultAdvance;
761- };
762- mlirOperationWalk (op.getOperation (), invalidatingCallback,
763- static_cast <void *>(&data), MlirWalkPreOrder);
764- }
765- void PyMlirContext::clearOperationsInside (MlirOperation op) {
766- PyOperationRef opRef = PyOperation::forOperation (getRef (), op);
767- clearOperationsInside (opRef->getOperation ());
768- }
769-
770- void PyMlirContext::clearOperationAndInside (PyOperationBase &op) {
771- MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
772- void *userData) {
773- PyMlirContextRef &contextRef = *static_cast <PyMlirContextRef *>(userData);
774- contextRef->clearOperation (op);
775- return MlirWalkResult::MlirWalkResultAdvance;
776- };
777- mlirOperationWalk (op.getOperation (), invalidatingCallback,
778- &op.getOperation ().getContext (), MlirWalkPreOrder);
779- }
780-
781- size_t PyMlirContext::getLiveModuleCount () { return liveModules.size (); }
782-
783711nb::object PyMlirContext::contextEnter (nb::object context) {
784712 return PyThreadContextEntry::pushContext (context);
785713}
@@ -1151,38 +1079,23 @@ PyLocation &DefaultingPyLocation::resolve() {
11511079PyModule::PyModule (PyMlirContextRef contextRef, MlirModule module )
11521080 : BaseContextObject(std::move(contextRef)), module (module ) {}
11531081
1154- PyModule::~PyModule () {
1155- nb::gil_scoped_acquire acquire;
1156- auto &liveModules = getContext ()->liveModules ;
1157- assert (liveModules.count (module .ptr ) == 1 &&
1158- " destroying module not in live map" );
1159- liveModules.erase (module .ptr );
1160- mlirModuleDestroy (module );
1161- }
1082+ PyModule::~PyModule () { mlirModuleDestroy (module ); }
11621083
11631084PyModuleRef PyModule::forModule (MlirModule module ) {
11641085 MlirContext context = mlirModuleGetContext (module );
11651086 PyMlirContextRef contextRef = PyMlirContext::forContext (context);
11661087
1167- nb::gil_scoped_acquire acquire;
1168- auto &liveModules = contextRef->liveModules ;
1169- auto it = liveModules.find (module .ptr );
1170- if (it == liveModules.end ()) {
1171- // Create.
1172- PyModule *unownedModule = new PyModule (std::move (contextRef), module );
1173- // Note that the default return value policy on cast is automatic_reference,
1174- // which does not take ownership (delete will not be called).
1175- // Just be explicit.
1176- nb::object pyRef = nb::cast (unownedModule, nb::rv_policy::take_ownership);
1177- unownedModule->handle = pyRef;
1178- liveModules[module .ptr ] =
1179- std::make_pair (unownedModule->handle , unownedModule);
1180- return PyModuleRef (unownedModule, std::move (pyRef));
1181- }
1182- // Use existing.
1183- PyModule *existing = it->second .second ;
1184- nb::object pyRef = nb::borrow<nb::object>(it->second .first );
1185- return PyModuleRef (existing, std::move (pyRef));
1088+ // Create.
1089+ PyModule *unownedModule = new PyModule (std::move (contextRef), module );
1090+ // Note that the default return value policy on cast is `automatic_reference`,
1091+ // which means "does not take ownership, does not call delete/dtor".
1092+ // We use `take_ownership`, which means "Python will call the C++ destructor
1093+ // and delete operator when the Python wrapper is garbage collected", because
1094+ // MlirModule actually wraps OwningOpRef<ModuleOp> (see mlirModuleCreateParse
1095+ // etc).
1096+ nb::object pyRef = nb::cast (unownedModule, nb::rv_policy::take_ownership);
1097+ unownedModule->handle = pyRef;
1098+ return PyModuleRef (unownedModule, std::move (pyRef));
11861099}
11871100
11881101nb::object PyModule::createFromCapsule (nb::object capsule) {
@@ -1207,15 +1120,11 @@ PyOperation::~PyOperation() {
12071120 // If the operation has already been invalidated there is nothing to do.
12081121 if (!valid)
12091122 return ;
1210-
1211- // Otherwise, invalidate the operation and remove it from live map when it is
1212- // attached.
1213- if (isAttached ()) {
1214- getContext ()->clearOperation (*this );
1215- } else {
1216- // And destroy it when it is detached, i.e. owned by Python, in which case
1217- // all nested operations must be invalidated at removed from the live map as
1218- // well.
1123+ // Otherwise, invalidate the operation when it is attached.
1124+ if (isAttached ())
1125+ setInvalid ();
1126+ else {
1127+ // And destroy it when it is detached, i.e. owned by Python.
12191128 erase ();
12201129 }
12211130}
@@ -1252,35 +1161,15 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
12521161PyOperationRef PyOperation::forOperation (PyMlirContextRef contextRef,
12531162 MlirOperation operation,
12541163 nb::object parentKeepAlive) {
1255- nb::ft_lock_guard lock (contextRef->liveOperationsMutex );
1256- auto &liveOperations = contextRef->liveOperations ;
1257- auto it = liveOperations.find (operation.ptr );
1258- if (it == liveOperations.end ()) {
1259- // Create.
1260- PyOperationRef result = createInstance (std::move (contextRef), operation,
1261- std::move (parentKeepAlive));
1262- liveOperations[operation.ptr ] =
1263- std::make_pair (result.getObject (), result.get ());
1264- return result;
1265- }
1266- // Use existing.
1267- PyOperation *existing = it->second .second ;
1268- nb::object pyRef = nb::borrow<nb::object>(it->second .first );
1269- return PyOperationRef (existing, std::move (pyRef));
1164+ return createInstance (std::move (contextRef), operation,
1165+ std::move (parentKeepAlive));
12701166}
12711167
12721168PyOperationRef PyOperation::createDetached (PyMlirContextRef contextRef,
12731169 MlirOperation operation,
12741170 nb::object parentKeepAlive) {
1275- nb::ft_lock_guard lock (contextRef->liveOperationsMutex );
1276- auto &liveOperations = contextRef->liveOperations ;
1277- assert (liveOperations.count (operation.ptr ) == 0 &&
1278- " cannot create detached operation that already exists" );
1279- (void )liveOperations;
12801171 PyOperationRef created = createInstance (std::move (contextRef), operation,
12811172 std::move (parentKeepAlive));
1282- liveOperations[operation.ptr ] =
1283- std::make_pair (created.getObject (), created.get ());
12841173 created->attached = false ;
12851174 return created;
12861175}
@@ -1652,7 +1541,7 @@ nb::object PyOperation::createOpView() {
16521541
16531542void PyOperation::erase () {
16541543 checkValid ();
1655- getContext ()-> clearOperationAndInside (* this );
1544+ setInvalid ( );
16561545 mlirOperationDestroy (operation);
16571546}
16581547
@@ -3023,14 +2912,6 @@ void mlir::python::populateIRCore(nb::module_ &m) {
30232912 PyMlirContextRef ref = PyMlirContext::forContext (self.get ());
30242913 return ref.releaseObject ();
30252914 })
3026- .def (" _get_live_operation_count" , &PyMlirContext::getLiveOperationCount)
3027- .def (" _get_live_operation_objects" ,
3028- &PyMlirContext::getLiveOperationObjects)
3029- .def (" _clear_live_operations" , &PyMlirContext::clearLiveOperations)
3030- .def (" _clear_live_operations_inside" ,
3031- nb::overload_cast<MlirOperation>(
3032- &PyMlirContext::clearOperationsInside))
3033- .def (" _get_live_module_count" , &PyMlirContext::getLiveModuleCount)
30342915 .def_prop_ro (MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
30352916 .def (MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
30362917 .def (" __enter__" , &PyMlirContext::contextEnter)
@@ -3348,7 +3229,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
33483229 // ----------------------------------------------------------------------------
33493230 nb::class_<PyModule>(m, " Module" , nb::is_weak_referenceable ())
33503231 .def_prop_ro (MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
3351- .def (MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
3232+ .def (MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule,
3233+ kModuleCAPICreate )
3234+ .def (" _clear_mlir_module" , &PyModule::clearMlirModule)
33523235 .def_static (
33533236 " parse" ,
33543237 [](const std::string &moduleAsm, DefaultingPyMlirContext context) {
@@ -3428,7 +3311,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
34283311 // Defer to the operation's __str__.
34293312 return self.attr (" operation" ).attr (" __str__" )();
34303313 },
3431- kOperationStrDunderDocstring );
3314+ kOperationStrDunderDocstring )
3315+ .def (
3316+ " __eq__" ,
3317+ [](PyModule &self, PyModule &other) {
3318+ return mlirModuleEqual (self.get (), other.get ());
3319+ },
3320+ " other" _a);
34323321
34333322 // ----------------------------------------------------------------------------
34343323 // Mapping of Operation.
@@ -3440,7 +3329,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
34403329 })
34413330 .def (" __eq__" ,
34423331 [](PyOperationBase &self, PyOperationBase &other) {
3443- return &self.getOperation () == &other.getOperation ();
3332+ return mlirOperationEqual (self.getOperation ().get (),
3333+ other.getOperation ().get ());
34443334 })
34453335 .def (" __eq__" ,
34463336 [](PyOperationBase &self, nb::object other) { return false ; })
@@ -3655,7 +3545,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
36553545 [](PyOperationBase &self) {
36563546 return PyOpSuccessors (self.getOperation ().getRef ());
36573547 },
3658- " Returns the list of Operation successors." );
3548+ " Returns the list of Operation successors." )
3549+ .def (" _set_invalid" , &PyOperation::setInvalid,
3550+ " Invalidate the operation." );
36593551
36603552 auto opViewClass =
36613553 nb::class_<PyOpView, PyOperationBase>(m, " OpView" )
@@ -3699,7 +3591,11 @@ void mlir::python::populateIRCore(nb::module_ &m) {
36993591 [](PyOperationBase &self) {
37003592 return PyOpSuccessors (self.getOperation ().getRef ());
37013593 },
3702- " Returns the list of Operation successors." );
3594+ " Returns the list of Operation successors." )
3595+ .def (
3596+ " _set_invalid" ,
3597+ [](PyOpView &self) { self.getOperation ().setInvalid (); },
3598+ " Invalidate the operation." );
37033599 opViewClass.attr (" _ODS_REGIONS" ) = nb::make_tuple (0 , true );
37043600 opViewClass.attr (" _ODS_OPERAND_SEGMENTS" ) = nb::none ();
37053601 opViewClass.attr (" _ODS_RESULT_SEGMENTS" ) = nb::none ();
0 commit comments