Skip to content

Commit c357923

Browse files
committed
[mlir][python] wip remove liveOpeartions
1 parent 03a23f0 commit c357923

File tree

11 files changed

+176
-263
lines changed

11 files changed

+176
-263
lines changed

mlir/include/mlir-c/IR.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,8 @@ MLIR_CAPI_EXPORTED MlirOperation mlirModuleGetOperation(MlirModule module);
415415
/// The returned module is null when the input operation was not a ModuleOp.
416416
MLIR_CAPI_EXPORTED MlirModule mlirModuleFromOperation(MlirOperation op);
417417

418+
MLIR_CAPI_EXPORTED bool mlirModuleEqual(MlirModule mod, MlirModule other);
419+
418420
//===----------------------------------------------------------------------===//
419421
// Operation state.
420422
//===----------------------------------------------------------------------===//

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 24 additions & 150 deletions
Original file line numberDiff line numberDiff line change
@@ -702,84 +702,6 @@ size_t PyMlirContext::getLiveCount() {
702702
return getLiveContexts().size();
703703
}
704704

705-
size_t PyMlirContext::getLiveOperationCount() {
706-
nb::ft_lock_guard lock(liveOperationsMutex);
707-
return liveOperations.size();
708-
}
709-
710-
std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
711-
std::vector<PyOperation *> liveObjects;
712-
nb::ft_lock_guard lock(liveOperationsMutex);
713-
for (auto &entry : liveOperations)
714-
liveObjects.push_back(entry.second.second);
715-
return liveObjects;
716-
}
717-
718-
size_t PyMlirContext::clearLiveOperations() {
719-
720-
LiveOperationMap operations;
721-
{
722-
nb::ft_lock_guard lock(liveOperationsMutex);
723-
std::swap(operations, liveOperations);
724-
}
725-
for (auto &op : operations)
726-
op.second.second->setInvalid();
727-
size_t numInvalidated = operations.size();
728-
return numInvalidated;
729-
}
730-
731-
void PyMlirContext::clearOperation(MlirOperation op) {
732-
PyOperation *py_op;
733-
{
734-
nb::ft_lock_guard lock(liveOperationsMutex);
735-
auto it = liveOperations.find(op.ptr);
736-
if (it == liveOperations.end()) {
737-
return;
738-
}
739-
py_op = it->second.second;
740-
liveOperations.erase(it);
741-
}
742-
py_op->setInvalid();
743-
}
744-
745-
void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
746-
typedef struct {
747-
PyOperation &rootOp;
748-
bool rootSeen;
749-
} callBackData;
750-
callBackData data{op.getOperation(), false};
751-
// Mark all ops below the op that the passmanager will be rooted
752-
// at (but not op itself - note the preorder) as invalid.
753-
MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
754-
void *userData) {
755-
callBackData *data = static_cast<callBackData *>(userData);
756-
if (LLVM_LIKELY(data->rootSeen))
757-
data->rootOp.getOperation().getContext()->clearOperation(op);
758-
else
759-
data->rootSeen = true;
760-
return MlirWalkResult::MlirWalkResultAdvance;
761-
};
762-
mlirOperationWalk(op.getOperation(), invalidatingCallback,
763-
static_cast<void *>(&data), MlirWalkPreOrder);
764-
}
765-
void PyMlirContext::clearOperationsInside(MlirOperation op) {
766-
PyOperationRef opRef = PyOperation::forOperation(getRef(), op);
767-
clearOperationsInside(opRef->getOperation());
768-
}
769-
770-
void PyMlirContext::clearOperationAndInside(PyOperationBase &op) {
771-
MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
772-
void *userData) {
773-
PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData);
774-
contextRef->clearOperation(op);
775-
return MlirWalkResult::MlirWalkResultAdvance;
776-
};
777-
mlirOperationWalk(op.getOperation(), invalidatingCallback,
778-
&op.getOperation().getContext(), MlirWalkPreOrder);
779-
}
780-
781-
size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
782-
783705
nb::object PyMlirContext::contextEnter(nb::object context) {
784706
return PyThreadContextEntry::pushContext(context);
785707
}
@@ -1151,38 +1073,20 @@ PyLocation &DefaultingPyLocation::resolve() {
11511073
PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
11521074
: BaseContextObject(std::move(contextRef)), module(module) {}
11531075

1154-
PyModule::~PyModule() {
1155-
nb::gil_scoped_acquire acquire;
1156-
auto &liveModules = getContext()->liveModules;
1157-
assert(liveModules.count(module.ptr) == 1 &&
1158-
"destroying module not in live map");
1159-
liveModules.erase(module.ptr);
1160-
mlirModuleDestroy(module);
1161-
}
1076+
PyModule::~PyModule() { mlirModuleDestroy(module); }
11621077

11631078
PyModuleRef PyModule::forModule(MlirModule module) {
11641079
MlirContext context = mlirModuleGetContext(module);
11651080
PyMlirContextRef contextRef = PyMlirContext::forContext(context);
11661081

1167-
nb::gil_scoped_acquire acquire;
1168-
auto &liveModules = contextRef->liveModules;
1169-
auto it = liveModules.find(module.ptr);
1170-
if (it == liveModules.end()) {
1171-
// Create.
1172-
PyModule *unownedModule = new PyModule(std::move(contextRef), module);
1173-
// Note that the default return value policy on cast is automatic_reference,
1174-
// which does not take ownership (delete will not be called).
1175-
// Just be explicit.
1176-
nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
1177-
unownedModule->handle = pyRef;
1178-
liveModules[module.ptr] =
1179-
std::make_pair(unownedModule->handle, unownedModule);
1180-
return PyModuleRef(unownedModule, std::move(pyRef));
1181-
}
1182-
// Use existing.
1183-
PyModule *existing = it->second.second;
1184-
nb::object pyRef = nb::borrow<nb::object>(it->second.first);
1185-
return PyModuleRef(existing, std::move(pyRef));
1082+
// Create.
1083+
PyModule *unownedModule = new PyModule(std::move(contextRef), module);
1084+
// Note that the default return value policy on cast is automatic_reference,
1085+
// which does not take ownership (delete will not be called).
1086+
// Just be explicit.
1087+
nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
1088+
unownedModule->handle = pyRef;
1089+
return PyModuleRef(unownedModule, std::move(pyRef));
11861090
}
11871091

11881092
nb::object PyModule::createFromCapsule(nb::object capsule) {
@@ -1207,16 +1111,8 @@ PyOperation::~PyOperation() {
12071111
// If the operation has already been invalidated there is nothing to do.
12081112
if (!valid)
12091113
return;
1210-
1211-
// Otherwise, invalidate the operation and remove it from live map when it is
1212-
// attached.
1213-
if (isAttached()) {
1214-
getContext()->clearOperation(*this);
1215-
} else {
1216-
// And destroy it when it is detached, i.e. owned by Python, in which case
1217-
// all nested operations must be invalidated at removed from the live map as
1218-
// well.
1219-
erase();
1114+
if (!isAttached()) {
1115+
mlirOperationDestroy(operation);
12201116
}
12211117
}
12221118

@@ -1246,41 +1142,22 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
12461142
if (parentKeepAlive) {
12471143
unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
12481144
}
1249-
return unownedOperation;
1145+
return PyOperationRef(unownedOperation, std::move(pyRef));
12501146
}
12511147

12521148
PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
12531149
MlirOperation operation,
12541150
nb::object parentKeepAlive) {
1255-
nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
1256-
auto &liveOperations = contextRef->liveOperations;
1257-
auto it = liveOperations.find(operation.ptr);
1258-
if (it == liveOperations.end()) {
1259-
// Create.
1260-
PyOperationRef result = createInstance(std::move(contextRef), operation,
1261-
std::move(parentKeepAlive));
1262-
liveOperations[operation.ptr] =
1263-
std::make_pair(result.getObject(), result.get());
1264-
return result;
1265-
}
1266-
// Use existing.
1267-
PyOperation *existing = it->second.second;
1268-
nb::object pyRef = nb::borrow<nb::object>(it->second.first);
1269-
return PyOperationRef(existing, std::move(pyRef));
1151+
// Create.
1152+
return createInstance(std::move(contextRef), operation,
1153+
std::move(parentKeepAlive));
12701154
}
12711155

12721156
PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
12731157
MlirOperation operation,
12741158
nb::object parentKeepAlive) {
1275-
nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
1276-
auto &liveOperations = contextRef->liveOperations;
1277-
assert(liveOperations.count(operation.ptr) == 0 &&
1278-
"cannot create detached operation that already exists");
1279-
(void)liveOperations;
12801159
PyOperationRef created = createInstance(std::move(contextRef), operation,
12811160
std::move(parentKeepAlive));
1282-
liveOperations[operation.ptr] =
1283-
std::make_pair(created.getObject(), created.get());
12841161
created->attached = false;
12851162
return created;
12861163
}
@@ -1652,7 +1529,6 @@ nb::object PyOperation::createOpView() {
16521529

16531530
void PyOperation::erase() {
16541531
checkValid();
1655-
getContext()->clearOperationAndInside(*this);
16561532
mlirOperationDestroy(operation);
16571533
}
16581534

@@ -2494,7 +2370,6 @@ class PyBlockArgumentList
24942370
: public Sliceable<PyBlockArgumentList, PyBlockArgument> {
24952371
public:
24962372
static constexpr const char *pyClassName = "BlockArgumentList";
2497-
using SliceableT = Sliceable<PyBlockArgumentList, PyBlockArgument>;
24982373

24992374
PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
25002375
intptr_t startIndex = 0, intptr_t length = -1,
@@ -3023,14 +2898,6 @@ void mlir::python::populateIRCore(nb::module_ &m) {
30232898
PyMlirContextRef ref = PyMlirContext::forContext(self.get());
30242899
return ref.releaseObject();
30252900
})
3026-
.def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
3027-
.def("_get_live_operation_objects",
3028-
&PyMlirContext::getLiveOperationObjects)
3029-
.def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
3030-
.def("_clear_live_operations_inside",
3031-
nb::overload_cast<MlirOperation>(
3032-
&PyMlirContext::clearOperationsInside))
3033-
.def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
30342901
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
30352902
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
30362903
.def("__enter__", &PyMlirContext::contextEnter)
@@ -3428,7 +3295,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
34283295
// Defer to the operation's __str__.
34293296
return self.attr("operation").attr("__str__")();
34303297
},
3431-
kOperationStrDunderDocstring);
3298+
kOperationStrDunderDocstring)
3299+
.def(
3300+
"__eq__",
3301+
[](PyModule &self, PyModule &other) {
3302+
return mlirModuleEqual(self.get(), other.get());
3303+
},
3304+
"other"_a);
34323305

34333306
//----------------------------------------------------------------------------
34343307
// Mapping of Operation.
@@ -3440,7 +3313,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
34403313
})
34413314
.def("__eq__",
34423315
[](PyOperationBase &self, PyOperationBase &other) {
3443-
return &self.getOperation() == &other.getOperation();
3316+
return mlirOperationEqual(self.getOperation().get(),
3317+
other.getOperation().get());
34443318
})
34453319
.def("__eq__",
34463320
[](PyOperationBase &self, nb::object other) { return false; })

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 0 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -218,40 +218,6 @@ class PyMlirContext {
218218
/// Gets the count of live context objects. Used for testing.
219219
static size_t getLiveCount();
220220

221-
/// Get a list of Python objects which are still in the live context map.
222-
std::vector<PyOperation *> getLiveOperationObjects();
223-
224-
/// Gets the count of live operations associated with this context.
225-
/// Used for testing.
226-
size_t getLiveOperationCount();
227-
228-
/// Clears the live operations map, returning the number of entries which were
229-
/// invalidated. To be used as a safety mechanism so that API end-users can't
230-
/// corrupt by holding references they shouldn't have accessed in the first
231-
/// place.
232-
size_t clearLiveOperations();
233-
234-
/// Removes an operation from the live operations map and sets it invalid.
235-
/// This is useful for when some non-bindings code destroys the operation and
236-
/// the bindings need to made aware. For example, in the case when pass
237-
/// manager is run.
238-
///
239-
/// Note that this does *NOT* clear the nested operations.
240-
void clearOperation(MlirOperation op);
241-
242-
/// Clears all operations nested inside the given op using
243-
/// `clearOperation(MlirOperation)`.
244-
void clearOperationsInside(PyOperationBase &op);
245-
void clearOperationsInside(MlirOperation op);
246-
247-
/// Clears the operaiton _and_ all operations inside using
248-
/// `clearOperation(MlirOperation)`.
249-
void clearOperationAndInside(PyOperationBase &op);
250-
251-
/// Gets the count of live modules associated with this context.
252-
/// Used for testing.
253-
size_t getLiveModuleCount();
254-
255221
/// Enter and exit the context manager.
256222
static nanobind::object contextEnter(nanobind::object context);
257223
void contextExit(const nanobind::object &excType,
@@ -278,25 +244,6 @@ class PyMlirContext {
278244
static nanobind::ft_mutex live_contexts_mutex;
279245
static LiveContextMap &getLiveContexts();
280246

281-
// Interns all live modules associated with this context. Modules tracked
282-
// in this map are valid. When a module is invalidated, it is removed
283-
// from this map, and while it still exists as an instance, any
284-
// attempt to access it will raise an error.
285-
using LiveModuleMap =
286-
llvm::DenseMap<const void *, std::pair<nanobind::handle, PyModule *>>;
287-
LiveModuleMap liveModules;
288-
289-
// Interns all live operations associated with this context. Operations
290-
// tracked in this map are valid. When an operation is invalidated, it is
291-
// removed from this map, and while it still exists as an instance, any
292-
// attempt to access it will raise an error.
293-
using LiveOperationMap =
294-
llvm::DenseMap<void *, std::pair<nanobind::handle, PyOperation *>>;
295-
nanobind::ft_mutex liveOperationsMutex;
296-
297-
// Guarded by liveOperationsMutex in free-threading mode.
298-
LiveOperationMap liveOperations;
299-
300247
bool emitErrorDiagnostics = false;
301248

302249
MlirContext context;

mlir/lib/Bindings/Python/Pass.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -159,11 +159,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
159159
"ValueError if the pipeline can't be parsed.")
160160
.def(
161161
"run",
162-
[](PyPassManager &passManager, PyOperationBase &op,
163-
bool invalidateOps) {
164-
if (invalidateOps) {
165-
op.getOperation().getContext()->clearOperationsInside(op);
166-
}
162+
[](PyPassManager &passManager, PyOperationBase &op) {
167163
// Actually run the pass manager.
168164
PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
169165
MlirLogicalResult status = mlirPassManagerRunOnOp(
@@ -172,7 +168,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
172168
throw MLIRError("Failure while executing pass pipeline",
173169
errors.take());
174170
},
175-
"operation"_a, "invalidate_ops"_a = true,
171+
"operation"_a,
176172
"Run the pass manager on the provided operation, raising an "
177173
"MLIRError on failure.")
178174
.def(

mlir/lib/Bindings/Python/TransformInterpreter.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ static void populateTransformInterpreterSubmodule(nb::module_ &m) {
6767
// root. This is awkward, but we don't have access to PyMlirContext
6868
// object here otherwise.
6969
nb::object obj = nb::cast(payloadRoot);
70-
obj.attr("context").attr("_clear_live_operations_inside")(payloadRoot);
7170

7271
MlirLogicalResult result = mlirTransformApplyNamedSequence(
7372
payloadRoot, transformRoot, transformModule, options.options);

mlir/lib/CAPI/IR/IR.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,10 @@ MlirModule mlirModuleFromOperation(MlirOperation op) {
465465
return wrap(dyn_cast<ModuleOp>(unwrap(op)));
466466
}
467467

468+
bool mlirModuleEqual(MlirModule mod, MlirModule other) {
469+
return unwrap(mod) == unwrap(other);
470+
}
471+
468472
//===----------------------------------------------------------------------===//
469473
// Operation state API.
470474
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)