@@ -702,84 +702,6 @@ size_t PyMlirContext::getLiveCount() {
702
702
return getLiveContexts ().size ();
703
703
}
704
704
705
- size_t PyMlirContext::getLiveOperationCount () {
706
- nb::ft_lock_guard lock (liveOperationsMutex);
707
- return liveOperations.size ();
708
- }
709
-
710
- std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects () {
711
- std::vector<PyOperation *> liveObjects;
712
- nb::ft_lock_guard lock (liveOperationsMutex);
713
- for (auto &entry : liveOperations)
714
- liveObjects.push_back (entry.second .second );
715
- return liveObjects;
716
- }
717
-
718
- size_t PyMlirContext::clearLiveOperations () {
719
-
720
- LiveOperationMap operations;
721
- {
722
- nb::ft_lock_guard lock (liveOperationsMutex);
723
- std::swap (operations, liveOperations);
724
- }
725
- for (auto &op : operations)
726
- op.second .second ->setInvalid ();
727
- size_t numInvalidated = operations.size ();
728
- return numInvalidated;
729
- }
730
-
731
- void PyMlirContext::clearOperation (MlirOperation op) {
732
- PyOperation *py_op;
733
- {
734
- nb::ft_lock_guard lock (liveOperationsMutex);
735
- auto it = liveOperations.find (op.ptr );
736
- if (it == liveOperations.end ()) {
737
- return ;
738
- }
739
- py_op = it->second .second ;
740
- liveOperations.erase (it);
741
- }
742
- py_op->setInvalid ();
743
- }
744
-
745
- void PyMlirContext::clearOperationsInside (PyOperationBase &op) {
746
- typedef struct {
747
- PyOperation &rootOp;
748
- bool rootSeen;
749
- } callBackData;
750
- callBackData data{op.getOperation (), false };
751
- // Mark all ops below the op that the passmanager will be rooted
752
- // at (but not op itself - note the preorder) as invalid.
753
- MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
754
- void *userData) {
755
- callBackData *data = static_cast <callBackData *>(userData);
756
- if (LLVM_LIKELY (data->rootSeen ))
757
- data->rootOp .getOperation ().getContext ()->clearOperation (op);
758
- else
759
- data->rootSeen = true ;
760
- return MlirWalkResult::MlirWalkResultAdvance;
761
- };
762
- mlirOperationWalk (op.getOperation (), invalidatingCallback,
763
- static_cast <void *>(&data), MlirWalkPreOrder);
764
- }
765
- void PyMlirContext::clearOperationsInside (MlirOperation op) {
766
- PyOperationRef opRef = PyOperation::forOperation (getRef (), op);
767
- clearOperationsInside (opRef->getOperation ());
768
- }
769
-
770
- void PyMlirContext::clearOperationAndInside (PyOperationBase &op) {
771
- MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
772
- void *userData) {
773
- PyMlirContextRef &contextRef = *static_cast <PyMlirContextRef *>(userData);
774
- contextRef->clearOperation (op);
775
- return MlirWalkResult::MlirWalkResultAdvance;
776
- };
777
- mlirOperationWalk (op.getOperation (), invalidatingCallback,
778
- &op.getOperation ().getContext (), MlirWalkPreOrder);
779
- }
780
-
781
- size_t PyMlirContext::getLiveModuleCount () { return liveModules.size (); }
782
-
783
705
nb::object PyMlirContext::contextEnter (nb::object context) {
784
706
return PyThreadContextEntry::pushContext (context);
785
707
}
@@ -1151,38 +1073,20 @@ PyLocation &DefaultingPyLocation::resolve() {
1151
1073
PyModule::PyModule (PyMlirContextRef contextRef, MlirModule module )
1152
1074
: BaseContextObject(std::move(contextRef)), module (module ) {}
1153
1075
1154
- PyModule::~PyModule () {
1155
- nb::gil_scoped_acquire acquire;
1156
- auto &liveModules = getContext ()->liveModules ;
1157
- assert (liveModules.count (module .ptr ) == 1 &&
1158
- " destroying module not in live map" );
1159
- liveModules.erase (module .ptr );
1160
- mlirModuleDestroy (module );
1161
- }
1076
+ PyModule::~PyModule () { mlirModuleDestroy (module ); }
1162
1077
1163
1078
PyModuleRef PyModule::forModule (MlirModule module ) {
1164
1079
MlirContext context = mlirModuleGetContext (module );
1165
1080
PyMlirContextRef contextRef = PyMlirContext::forContext (context);
1166
1081
1167
- nb::gil_scoped_acquire acquire;
1168
- auto &liveModules = contextRef->liveModules ;
1169
- auto it = liveModules.find (module .ptr );
1170
- if (it == liveModules.end ()) {
1171
- // Create.
1172
- PyModule *unownedModule = new PyModule (std::move (contextRef), module );
1173
- // Note that the default return value policy on cast is automatic_reference,
1174
- // which does not take ownership (delete will not be called).
1175
- // Just be explicit.
1176
- nb::object pyRef = nb::cast (unownedModule, nb::rv_policy::take_ownership);
1177
- unownedModule->handle = pyRef;
1178
- liveModules[module .ptr ] =
1179
- std::make_pair (unownedModule->handle , unownedModule);
1180
- return PyModuleRef (unownedModule, std::move (pyRef));
1181
- }
1182
- // Use existing.
1183
- PyModule *existing = it->second .second ;
1184
- nb::object pyRef = nb::borrow<nb::object>(it->second .first );
1185
- return PyModuleRef (existing, std::move (pyRef));
1082
+ // Create.
1083
+ PyModule *unownedModule = new PyModule (std::move (contextRef), module );
1084
+ // Note that the default return value policy on cast is automatic_reference,
1085
+ // which does not take ownership (delete will not be called).
1086
+ // Just be explicit.
1087
+ nb::object pyRef = nb::cast (unownedModule, nb::rv_policy::take_ownership);
1088
+ unownedModule->handle = pyRef;
1089
+ return PyModuleRef (unownedModule, std::move (pyRef));
1186
1090
}
1187
1091
1188
1092
nb::object PyModule::createFromCapsule (nb::object capsule) {
@@ -1207,16 +1111,8 @@ PyOperation::~PyOperation() {
1207
1111
// If the operation has already been invalidated there is nothing to do.
1208
1112
if (!valid)
1209
1113
return ;
1210
-
1211
- // Otherwise, invalidate the operation and remove it from live map when it is
1212
- // attached.
1213
- if (isAttached ()) {
1214
- getContext ()->clearOperation (*this );
1215
- } else {
1216
- // And destroy it when it is detached, i.e. owned by Python, in which case
1217
- // all nested operations must be invalidated at removed from the live map as
1218
- // well.
1219
- erase ();
1114
+ if (!isAttached ()) {
1115
+ mlirOperationDestroy (operation);
1220
1116
}
1221
1117
}
1222
1118
@@ -1246,41 +1142,22 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
1246
1142
if (parentKeepAlive) {
1247
1143
unownedOperation->parentKeepAlive = std::move (parentKeepAlive);
1248
1144
}
1249
- return unownedOperation;
1145
+ return PyOperationRef ( unownedOperation, std::move (pyRef)) ;
1250
1146
}
1251
1147
1252
1148
PyOperationRef PyOperation::forOperation (PyMlirContextRef contextRef,
1253
1149
MlirOperation operation,
1254
1150
nb::object parentKeepAlive) {
1255
- nb::ft_lock_guard lock (contextRef->liveOperationsMutex );
1256
- auto &liveOperations = contextRef->liveOperations ;
1257
- auto it = liveOperations.find (operation.ptr );
1258
- if (it == liveOperations.end ()) {
1259
- // Create.
1260
- PyOperationRef result = createInstance (std::move (contextRef), operation,
1261
- std::move (parentKeepAlive));
1262
- liveOperations[operation.ptr ] =
1263
- std::make_pair (result.getObject (), result.get ());
1264
- return result;
1265
- }
1266
- // Use existing.
1267
- PyOperation *existing = it->second .second ;
1268
- nb::object pyRef = nb::borrow<nb::object>(it->second .first );
1269
- return PyOperationRef (existing, std::move (pyRef));
1151
+ // Create.
1152
+ return createInstance (std::move (contextRef), operation,
1153
+ std::move (parentKeepAlive));
1270
1154
}
1271
1155
1272
1156
PyOperationRef PyOperation::createDetached (PyMlirContextRef contextRef,
1273
1157
MlirOperation operation,
1274
1158
nb::object parentKeepAlive) {
1275
- nb::ft_lock_guard lock (contextRef->liveOperationsMutex );
1276
- auto &liveOperations = contextRef->liveOperations ;
1277
- assert (liveOperations.count (operation.ptr ) == 0 &&
1278
- " cannot create detached operation that already exists" );
1279
- (void )liveOperations;
1280
1159
PyOperationRef created = createInstance (std::move (contextRef), operation,
1281
1160
std::move (parentKeepAlive));
1282
- liveOperations[operation.ptr ] =
1283
- std::make_pair (created.getObject (), created.get ());
1284
1161
created->attached = false ;
1285
1162
return created;
1286
1163
}
@@ -1652,7 +1529,6 @@ nb::object PyOperation::createOpView() {
1652
1529
1653
1530
void PyOperation::erase () {
1654
1531
checkValid ();
1655
- getContext ()->clearOperationAndInside (*this );
1656
1532
mlirOperationDestroy (operation);
1657
1533
}
1658
1534
@@ -2494,7 +2370,6 @@ class PyBlockArgumentList
2494
2370
: public Sliceable<PyBlockArgumentList, PyBlockArgument> {
2495
2371
public:
2496
2372
static constexpr const char *pyClassName = " BlockArgumentList" ;
2497
- using SliceableT = Sliceable<PyBlockArgumentList, PyBlockArgument>;
2498
2373
2499
2374
PyBlockArgumentList (PyOperationRef operation, MlirBlock block,
2500
2375
intptr_t startIndex = 0 , intptr_t length = -1 ,
@@ -3023,14 +2898,6 @@ void mlir::python::populateIRCore(nb::module_ &m) {
3023
2898
PyMlirContextRef ref = PyMlirContext::forContext (self.get ());
3024
2899
return ref.releaseObject ();
3025
2900
})
3026
- .def (" _get_live_operation_count" , &PyMlirContext::getLiveOperationCount)
3027
- .def (" _get_live_operation_objects" ,
3028
- &PyMlirContext::getLiveOperationObjects)
3029
- .def (" _clear_live_operations" , &PyMlirContext::clearLiveOperations)
3030
- .def (" _clear_live_operations_inside" ,
3031
- nb::overload_cast<MlirOperation>(
3032
- &PyMlirContext::clearOperationsInside))
3033
- .def (" _get_live_module_count" , &PyMlirContext::getLiveModuleCount)
3034
2901
.def_prop_ro (MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
3035
2902
.def (MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
3036
2903
.def (" __enter__" , &PyMlirContext::contextEnter)
@@ -3428,7 +3295,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
3428
3295
// Defer to the operation's __str__.
3429
3296
return self.attr (" operation" ).attr (" __str__" )();
3430
3297
},
3431
- kOperationStrDunderDocstring );
3298
+ kOperationStrDunderDocstring )
3299
+ .def (
3300
+ " __eq__" ,
3301
+ [](PyModule &self, PyModule &other) {
3302
+ return mlirModuleEqual (self.get (), other.get ());
3303
+ },
3304
+ " other" _a);
3432
3305
3433
3306
// ----------------------------------------------------------------------------
3434
3307
// Mapping of Operation.
@@ -3440,7 +3313,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
3440
3313
})
3441
3314
.def (" __eq__" ,
3442
3315
[](PyOperationBase &self, PyOperationBase &other) {
3443
- return &self.getOperation () == &other.getOperation ();
3316
+ return mlirOperationEqual (self.getOperation ().get (),
3317
+ other.getOperation ().get ());
3444
3318
})
3445
3319
.def (" __eq__" ,
3446
3320
[](PyOperationBase &self, nb::object other) { return false ; })
0 commit comments