@@ -67,6 +67,12 @@ Returns a new MlirModule or raises an MLIRError if the parsing fails.
67
67
See also: https://mlir.llvm.org/docs/LangRef/
68
68
)" ;
69
69
70
+ static const char kModuleCAPICreate [] =
71
+ R"( Creates a Module from a MlirModule wrapped by a capsule (i.e. module._CAPIPtr).
72
+ Note this returns a new object BUT _clear_mlir_module(module) must be called to
73
+ prevent double-frees (of the underlying mlir::Module).
74
+ )" ;
75
+
70
76
static const char kOperationCreateDocstring [] =
71
77
R"( Creates a new operation.
72
78
@@ -702,84 +708,6 @@ size_t PyMlirContext::getLiveCount() {
702
708
return getLiveContexts ().size ();
703
709
}
704
710
705
- size_t PyMlirContext::getLiveOperationCount () {
706
- nb::ft_lock_guard lock (liveOperationsMutex);
707
- return liveOperations.size ();
708
- }
709
-
710
- std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects () {
711
- std::vector<PyOperation *> liveObjects;
712
- nb::ft_lock_guard lock (liveOperationsMutex);
713
- for (auto &entry : liveOperations)
714
- liveObjects.push_back (entry.second .second );
715
- return liveObjects;
716
- }
717
-
718
- size_t PyMlirContext::clearLiveOperations () {
719
-
720
- LiveOperationMap operations;
721
- {
722
- nb::ft_lock_guard lock (liveOperationsMutex);
723
- std::swap (operations, liveOperations);
724
- }
725
- for (auto &op : operations)
726
- op.second .second ->setInvalid ();
727
- size_t numInvalidated = operations.size ();
728
- return numInvalidated;
729
- }
730
-
731
- void PyMlirContext::clearOperation (MlirOperation op) {
732
- PyOperation *pyOp;
733
- {
734
- nb::ft_lock_guard lock (liveOperationsMutex);
735
- auto it = liveOperations.find (op.ptr );
736
- if (it == liveOperations.end ()) {
737
- return ;
738
- }
739
- pyOp = it->second .second ;
740
- liveOperations.erase (it);
741
- }
742
- pyOp->setInvalid ();
743
- }
744
-
745
- void PyMlirContext::clearOperationsInside (PyOperationBase &op) {
746
- using callBackData = struct {
747
- PyOperation &rootOp;
748
- bool rootSeen;
749
- };
750
- callBackData data{op.getOperation (), false };
751
- // Mark all ops below the op that the passmanager will be rooted
752
- // at (but not op itself - note the preorder) as invalid.
753
- MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
754
- void *userData) {
755
- callBackData *data = static_cast <callBackData *>(userData);
756
- if (LLVM_LIKELY (data->rootSeen ))
757
- data->rootOp .getOperation ().getContext ()->clearOperation (op);
758
- else
759
- data->rootSeen = true ;
760
- return MlirWalkResult::MlirWalkResultAdvance;
761
- };
762
- mlirOperationWalk (op.getOperation (), invalidatingCallback,
763
- static_cast <void *>(&data), MlirWalkPreOrder);
764
- }
765
- void PyMlirContext::clearOperationsInside (MlirOperation op) {
766
- PyOperationRef opRef = PyOperation::forOperation (getRef (), op);
767
- clearOperationsInside (opRef->getOperation ());
768
- }
769
-
770
- void PyMlirContext::clearOperationAndInside (PyOperationBase &op) {
771
- MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
772
- void *userData) {
773
- PyMlirContextRef &contextRef = *static_cast <PyMlirContextRef *>(userData);
774
- contextRef->clearOperation (op);
775
- return MlirWalkResult::MlirWalkResultAdvance;
776
- };
777
- mlirOperationWalk (op.getOperation (), invalidatingCallback,
778
- &op.getOperation ().getContext (), MlirWalkPreOrder);
779
- }
780
-
781
- size_t PyMlirContext::getLiveModuleCount () { return liveModules.size (); }
782
-
783
711
nb::object PyMlirContext::contextEnter (nb::object context) {
784
712
return PyThreadContextEntry::pushContext (context);
785
713
}
@@ -1151,38 +1079,23 @@ PyLocation &DefaultingPyLocation::resolve() {
1151
1079
PyModule::PyModule (PyMlirContextRef contextRef, MlirModule module )
1152
1080
: BaseContextObject(std::move(contextRef)), module (module ) {}
1153
1081
1154
- PyModule::~PyModule () {
1155
- nb::gil_scoped_acquire acquire;
1156
- auto &liveModules = getContext ()->liveModules ;
1157
- assert (liveModules.count (module .ptr ) == 1 &&
1158
- " destroying module not in live map" );
1159
- liveModules.erase (module .ptr );
1160
- mlirModuleDestroy (module );
1161
- }
1082
+ PyModule::~PyModule () { mlirModuleDestroy (module ); }
1162
1083
1163
1084
PyModuleRef PyModule::forModule (MlirModule module ) {
1164
1085
MlirContext context = mlirModuleGetContext (module );
1165
1086
PyMlirContextRef contextRef = PyMlirContext::forContext (context);
1166
1087
1167
- nb::gil_scoped_acquire acquire;
1168
- auto &liveModules = contextRef->liveModules ;
1169
- auto it = liveModules.find (module .ptr );
1170
- if (it == liveModules.end ()) {
1171
- // Create.
1172
- PyModule *unownedModule = new PyModule (std::move (contextRef), module );
1173
- // Note that the default return value policy on cast is automatic_reference,
1174
- // which does not take ownership (delete will not be called).
1175
- // Just be explicit.
1176
- nb::object pyRef = nb::cast (unownedModule, nb::rv_policy::take_ownership);
1177
- unownedModule->handle = pyRef;
1178
- liveModules[module .ptr ] =
1179
- std::make_pair (unownedModule->handle , unownedModule);
1180
- return PyModuleRef (unownedModule, std::move (pyRef));
1181
- }
1182
- // Use existing.
1183
- PyModule *existing = it->second .second ;
1184
- nb::object pyRef = nb::borrow<nb::object>(it->second .first );
1185
- return PyModuleRef (existing, std::move (pyRef));
1088
+ // Create.
1089
+ PyModule *unownedModule = new PyModule (std::move (contextRef), module );
1090
+ // Note that the default return value policy on cast is `automatic_reference`,
1091
+ // which means "does not take ownership, does not call delete/dtor".
1092
+ // We use `take_ownership`, which means "Python will call the C++ destructor
1093
+ // and delete operator when the Python wrapper is garbage collected", because
1094
+ // MlirModule actually wraps OwningOpRef<ModuleOp> (see mlirModuleCreateParse
1095
+ // etc).
1096
+ nb::object pyRef = nb::cast (unownedModule, nb::rv_policy::take_ownership);
1097
+ unownedModule->handle = pyRef;
1098
+ return PyModuleRef (unownedModule, std::move (pyRef));
1186
1099
}
1187
1100
1188
1101
nb::object PyModule::createFromCapsule (nb::object capsule) {
@@ -1207,15 +1120,11 @@ PyOperation::~PyOperation() {
1207
1120
// If the operation has already been invalidated there is nothing to do.
1208
1121
if (!valid)
1209
1122
return ;
1210
-
1211
- // Otherwise, invalidate the operation and remove it from live map when it is
1212
- // attached.
1213
- if (isAttached ()) {
1214
- getContext ()->clearOperation (*this );
1215
- } else {
1216
- // And destroy it when it is detached, i.e. owned by Python, in which case
1217
- // all nested operations must be invalidated at removed from the live map as
1218
- // well.
1123
+ // Otherwise, invalidate the operation when it is attached.
1124
+ if (isAttached ())
1125
+ setInvalid ();
1126
+ else {
1127
+ // And destroy it when it is detached, i.e. owned by Python.
1219
1128
erase ();
1220
1129
}
1221
1130
}
@@ -1252,35 +1161,15 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
1252
1161
PyOperationRef PyOperation::forOperation (PyMlirContextRef contextRef,
1253
1162
MlirOperation operation,
1254
1163
nb::object parentKeepAlive) {
1255
- nb::ft_lock_guard lock (contextRef->liveOperationsMutex );
1256
- auto &liveOperations = contextRef->liveOperations ;
1257
- auto it = liveOperations.find (operation.ptr );
1258
- if (it == liveOperations.end ()) {
1259
- // Create.
1260
- PyOperationRef result = createInstance (std::move (contextRef), operation,
1261
- std::move (parentKeepAlive));
1262
- liveOperations[operation.ptr ] =
1263
- std::make_pair (result.getObject (), result.get ());
1264
- return result;
1265
- }
1266
- // Use existing.
1267
- PyOperation *existing = it->second .second ;
1268
- nb::object pyRef = nb::borrow<nb::object>(it->second .first );
1269
- return PyOperationRef (existing, std::move (pyRef));
1164
+ return createInstance (std::move (contextRef), operation,
1165
+ std::move (parentKeepAlive));
1270
1166
}
1271
1167
1272
1168
PyOperationRef PyOperation::createDetached (PyMlirContextRef contextRef,
1273
1169
MlirOperation operation,
1274
1170
nb::object parentKeepAlive) {
1275
- nb::ft_lock_guard lock (contextRef->liveOperationsMutex );
1276
- auto &liveOperations = contextRef->liveOperations ;
1277
- assert (liveOperations.count (operation.ptr ) == 0 &&
1278
- " cannot create detached operation that already exists" );
1279
- (void )liveOperations;
1280
1171
PyOperationRef created = createInstance (std::move (contextRef), operation,
1281
1172
std::move (parentKeepAlive));
1282
- liveOperations[operation.ptr ] =
1283
- std::make_pair (created.getObject (), created.get ());
1284
1173
created->attached = false ;
1285
1174
return created;
1286
1175
}
@@ -1652,7 +1541,7 @@ nb::object PyOperation::createOpView() {
1652
1541
1653
1542
void PyOperation::erase () {
1654
1543
checkValid ();
1655
- getContext ()-> clearOperationAndInside (* this );
1544
+ setInvalid ( );
1656
1545
mlirOperationDestroy (operation);
1657
1546
}
1658
1547
@@ -3023,14 +2912,6 @@ void mlir::python::populateIRCore(nb::module_ &m) {
3023
2912
PyMlirContextRef ref = PyMlirContext::forContext (self.get ());
3024
2913
return ref.releaseObject ();
3025
2914
})
3026
- .def (" _get_live_operation_count" , &PyMlirContext::getLiveOperationCount)
3027
- .def (" _get_live_operation_objects" ,
3028
- &PyMlirContext::getLiveOperationObjects)
3029
- .def (" _clear_live_operations" , &PyMlirContext::clearLiveOperations)
3030
- .def (" _clear_live_operations_inside" ,
3031
- nb::overload_cast<MlirOperation>(
3032
- &PyMlirContext::clearOperationsInside))
3033
- .def (" _get_live_module_count" , &PyMlirContext::getLiveModuleCount)
3034
2915
.def_prop_ro (MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
3035
2916
.def (MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
3036
2917
.def (" __enter__" , &PyMlirContext::contextEnter)
@@ -3348,7 +3229,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
3348
3229
// ----------------------------------------------------------------------------
3349
3230
nb::class_<PyModule>(m, " Module" , nb::is_weak_referenceable ())
3350
3231
.def_prop_ro (MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
3351
- .def (MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
3232
+ .def (MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule,
3233
+ kModuleCAPICreate )
3234
+ .def (" _clear_mlir_module" , &PyModule::clearMlirModule)
3352
3235
.def_static (
3353
3236
" parse" ,
3354
3237
[](const std::string &moduleAsm, DefaultingPyMlirContext context) {
@@ -3428,7 +3311,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
3428
3311
// Defer to the operation's __str__.
3429
3312
return self.attr (" operation" ).attr (" __str__" )();
3430
3313
},
3431
- kOperationStrDunderDocstring );
3314
+ kOperationStrDunderDocstring )
3315
+ .def (
3316
+ " __eq__" ,
3317
+ [](PyModule &self, PyModule &other) {
3318
+ return mlirModuleEqual (self.get (), other.get ());
3319
+ },
3320
+ " other" _a);
3432
3321
3433
3322
// ----------------------------------------------------------------------------
3434
3323
// Mapping of Operation.
@@ -3440,7 +3329,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
3440
3329
})
3441
3330
.def (" __eq__" ,
3442
3331
[](PyOperationBase &self, PyOperationBase &other) {
3443
- return &self.getOperation () == &other.getOperation ();
3332
+ return mlirOperationEqual (self.getOperation ().get (),
3333
+ other.getOperation ().get ());
3444
3334
})
3445
3335
.def (" __eq__" ,
3446
3336
[](PyOperationBase &self, nb::object other) { return false ; })
@@ -3655,7 +3545,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
3655
3545
[](PyOperationBase &self) {
3656
3546
return PyOpSuccessors (self.getOperation ().getRef ());
3657
3547
},
3658
- " Returns the list of Operation successors." );
3548
+ " Returns the list of Operation successors." )
3549
+ .def (" _set_invalid" , &PyOperation::setInvalid,
3550
+ " Invalidate the operation." );
3659
3551
3660
3552
auto opViewClass =
3661
3553
nb::class_<PyOpView, PyOperationBase>(m, " OpView" )
@@ -3699,7 +3591,11 @@ void mlir::python::populateIRCore(nb::module_ &m) {
3699
3591
[](PyOperationBase &self) {
3700
3592
return PyOpSuccessors (self.getOperation ().getRef ());
3701
3593
},
3702
- " Returns the list of Operation successors." );
3594
+ " Returns the list of Operation successors." )
3595
+ .def (
3596
+ " _set_invalid" ,
3597
+ [](PyOpView &self) { self.getOperation ().setInvalid (); },
3598
+ " Invalidate the operation." );
3703
3599
opViewClass.attr (" _ODS_REGIONS" ) = nb::make_tuple (0 , true );
3704
3600
opViewClass.attr (" _ODS_OPERAND_SEGMENTS" ) = nb::none ();
3705
3601
opViewClass.attr (" _ODS_RESULT_SEGMENTS" ) = nb::none ();
0 commit comments