Skip to content

Commit 81cbd97

Browse files
authored
[MLIR][Python] remove nb::typed to fix bazel build (#160183)
#157930 broke bazel build (see #157930 (comment)) because bazel is stricter on implicit conversions (some difference in flags passed to clang). This PR fixes by moving/removing `nb::typed`. EDIT: and also the overlay...
1 parent c526c70 commit 81cbd97

File tree

7 files changed

+34
-47
lines changed

7 files changed

+34
-47
lines changed

mlir/lib/Bindings/Python/IRAttributes.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,7 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
485485

486486
PyArrayAttributeIterator &dunderIter() { return *this; }
487487

488-
nb::typed<nb::object, PyAttribute> dunderNext() {
488+
nb::object dunderNext() {
489489
// TODO: Throw is an inefficient way to stop iteration.
490490
if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
491491
throw nb::stop_iteration();

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -513,7 +513,7 @@ class PyOperationIterator {
513513

514514
PyOperationIterator &dunderIter() { return *this; }
515515

516-
nb::typed<nb::object, PyOpView> dunderNext() {
516+
nb::object dunderNext() {
517517
parentOperation->checkValid();
518518
if (mlirOperationIsNull(next)) {
519519
throw nb::stop_iteration();
@@ -562,7 +562,7 @@ class PyOperationList {
562562
return count;
563563
}
564564

565-
nb::typed<nb::object, PyOpView> dunderGetItem(intptr_t index) {
565+
nb::object dunderGetItem(intptr_t index) {
566566
parentOperation->checkValid();
567567
if (index < 0) {
568568
index += dunderLen();
@@ -1534,7 +1534,7 @@ nb::object PyOperation::create(std::string_view name,
15341534
return created.getObject();
15351535
}
15361536

1537-
nb::typed<nb::object, PyOpView> PyOperation::clone(const nb::object &maybeIp) {
1537+
nb::object PyOperation::clone(const nb::object &maybeIp) {
15381538
MlirOperation clonedOperation = mlirOperationClone(operation);
15391539
PyOperationRef cloned =
15401540
PyOperation::createDetached(getContext(), clonedOperation);
@@ -1543,7 +1543,7 @@ nb::typed<nb::object, PyOpView> PyOperation::clone(const nb::object &maybeIp) {
15431543
return cloned->createOpView();
15441544
}
15451545

1546-
nb::typed<nb::object, PyOpView> PyOperation::createOpView() {
1546+
nb::object PyOperation::createOpView() {
15471547
checkValid();
15481548
MlirIdentifier ident = mlirOperationGetName(get());
15491549
MlirStringRef identStr = mlirIdentifierStr(ident);
@@ -1638,9 +1638,9 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {
16381638

16391639
/// Returns the list of types of the values held by container.
16401640
template <typename Container>
1641-
static std::vector<nb::typed<nb::object, PyType>>
1642-
getValueTypes(Container &container, PyMlirContextRef &context) {
1643-
std::vector<nb::typed<nb::object, PyType>> result;
1641+
static std::vector<nb::object> getValueTypes(Container &container,
1642+
PyMlirContextRef &context) {
1643+
std::vector<nb::object> result;
16441644
result.reserve(container.size());
16451645
for (int i = 0, e = container.size(); i < e; ++i) {
16461646
result.push_back(PyType(context->getRef(),
@@ -2133,7 +2133,7 @@ PyAttribute PyAttribute::createFromCapsule(nb::object capsule) {
21332133
PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
21342134
}
21352135

2136-
nb::typed<nb::object, PyAttribute> PyAttribute::maybeDownCast() {
2136+
nb::object PyAttribute::maybeDownCast() {
21372137
MlirTypeID mlirTypeID = mlirAttributeGetTypeID(this->get());
21382138
assert(!mlirTypeIDIsNull(mlirTypeID) &&
21392139
"mlirTypeID was expected to be non-null.");
@@ -2179,7 +2179,7 @@ PyType PyType::createFromCapsule(nb::object capsule) {
21792179
rawType);
21802180
}
21812181

2182-
nb::typed<nb::object, PyType> PyType::maybeDownCast() {
2182+
nb::object PyType::maybeDownCast() {
21832183
MlirTypeID mlirTypeID = mlirTypeGetTypeID(this->get());
21842184
assert(!mlirTypeIDIsNull(mlirTypeID) &&
21852185
"mlirTypeID was expected to be non-null.");
@@ -2219,7 +2219,7 @@ nb::object PyValue::getCapsule() {
22192219
return nb::steal<nb::object>(mlirPythonValueToCapsule(get()));
22202220
}
22212221

2222-
nanobind::typed<nanobind::object, PyValue> PyValue::maybeDownCast() {
2222+
nb::object PyValue::maybeDownCast() {
22232223
MlirType type = mlirValueGetType(get());
22242224
MlirTypeID mlirTypeID = mlirTypeGetTypeID(type);
22252225
assert(!mlirTypeIDIsNull(mlirTypeID) &&
@@ -2263,8 +2263,7 @@ PySymbolTable::PySymbolTable(PyOperationBase &operation)
22632263
}
22642264
}
22652265

2266-
nb::typed<nb::object, PyOpView>
2267-
PySymbolTable::dunderGetItem(const std::string &name) {
2266+
nb::object PySymbolTable::dunderGetItem(const std::string &name) {
22682267
operation->checkValid();
22692268
MlirOperation symbol = mlirSymbolTableLookup(
22702269
symbolTable, mlirStringRefCreate(name.data(), name.length()));
@@ -2678,8 +2677,7 @@ class PyOpAttributeMap {
26782677
PyOpAttributeMap(PyOperationRef operation)
26792678
: operation(std::move(operation)) {}
26802679

2681-
nb::typed<nb::object, PyAttribute>
2682-
dunderGetItemNamed(const std::string &name) {
2680+
nb::object dunderGetItemNamed(const std::string &name) {
26832681
MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
26842682
toMlirStringRef(name));
26852683
if (mlirAttributeIsNull(attr)) {

mlir/lib/Bindings/Python/IRInterfaces.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ class PyConcreteOpInterface {
223223
/// Returns the opview of the operation instance from which this object was
224224
/// constructed. Throws a type error if this object was constructed form a
225225
/// subclass of OpView.
226-
nb::typed<nb::object, PyOpView> getOpView() {
226+
nb::object getOpView() {
227227
if (operation == nullptr) {
228228
throw nb::type_error("Cannot get an opview from a static interface");
229229
}

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ class PyObjectRef {
7676
/// Releases the object held by this instance, returning it.
7777
/// This is the proper thing to return from a function that wants to return
7878
/// the reference. Note that this does not work from initializers.
79-
nanobind::typed<nanobind::object, T> releaseObject() {
79+
nanobind::object releaseObject() {
8080
assert(referrent && object);
8181
referrent = nullptr;
8282
auto stolen = std::move(object);
@@ -88,12 +88,14 @@ class PyObjectRef {
8888
assert(referrent && object);
8989
return referrent;
9090
}
91-
nanobind::typed<nanobind::object, T> getObject() {
91+
nanobind::object getObject() {
9292
assert(referrent && object);
9393
return object;
9494
}
9595
operator bool() const { return referrent && object; }
9696

97+
using NBTypedT = nanobind::typed<nanobind::object, T>;
98+
9799
private:
98100
T *referrent;
99101
nanobind::object object;
@@ -680,7 +682,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
680682
PyLocation &location, const nanobind::object &ip, bool inferType);
681683

682684
/// Creates an OpView suitable for this operation.
683-
nanobind::typed<nanobind::object, PyOpView> createOpView();
685+
nanobind::object createOpView();
684686

685687
/// Erases the underlying MlirOperation, removes its pointer from the
686688
/// parent context's live operations map, and sets the valid bit false.
@@ -690,7 +692,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
690692
void setInvalid() { valid = false; }
691693

692694
/// Clones this operation.
693-
nanobind::typed<nanobind::object, PyOpView> clone(const nanobind::object &ip);
695+
nanobind::object clone(const nanobind::object &ip);
694696

695697
PyOperation(PyMlirContextRef contextRef, MlirOperation operation);
696698

@@ -890,7 +892,7 @@ class PyType : public BaseContextObject {
890892
/// is taken by calling this function.
891893
static PyType createFromCapsule(nanobind::object capsule);
892894

893-
nanobind::typed<nanobind::object, PyType> maybeDownCast();
895+
nanobind::object maybeDownCast();
894896

895897
private:
896898
MlirType type;
@@ -1020,7 +1022,7 @@ class PyAttribute : public BaseContextObject {
10201022
/// is taken by calling this function.
10211023
static PyAttribute createFromCapsule(nanobind::object capsule);
10221024

1023-
nanobind::typed<nanobind::object, PyAttribute> maybeDownCast();
1025+
nanobind::object maybeDownCast();
10241026

10251027
private:
10261028
MlirAttribute attr;
@@ -1178,7 +1180,7 @@ class PyValue {
11781180
/// Gets a capsule wrapping the void* within the MlirValue.
11791181
nanobind::object getCapsule();
11801182

1181-
nanobind::typed<nanobind::object, PyValue> maybeDownCast();
1183+
nanobind::object maybeDownCast();
11821184

11831185
/// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of
11841186
/// the underlying MlirValue is still tied to the owning operation.
@@ -1269,8 +1271,7 @@ class PySymbolTable {
12691271

12701272
/// Returns the symbol (opview) with the given name, throws if there is no
12711273
/// such symbol in the table.
1272-
nanobind::typed<nanobind::object, PyOpView>
1273-
dunderGetItem(const std::string &name);
1274+
nanobind::object dunderGetItem(const std::string &name);
12741275

12751276
/// Removes the given operation from the symbol table and erases it.
12761277
void erase(PyOperationBase &symbol);

mlir/lib/Bindings/Python/IRTypes.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -731,7 +731,8 @@ class PyRankedTensorType
731731
MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get());
732732
if (mlirAttributeIsNull(encoding))
733733
return std::nullopt;
734-
return PyAttribute(self.getContext(), encoding).maybeDownCast();
734+
return nb::cast<nb::typed<nb::object, PyAttribute>>(
735+
PyAttribute(self.getContext(), encoding).maybeDownCast());
735736
});
736737
}
737738
};
@@ -793,9 +794,9 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
793794
.def_prop_ro(
794795
"layout",
795796
[](PyMemRefType &self) -> nb::typed<nb::object, PyAttribute> {
796-
return PyAttribute(self.getContext(),
797-
mlirMemRefTypeGetLayout(self))
798-
.maybeDownCast();
797+
return nb::cast<nb::typed<nb::object, PyAttribute>>(
798+
PyAttribute(self.getContext(), mlirMemRefTypeGetLayout(self))
799+
.maybeDownCast());
799800
},
800801
"The layout of the MemRef type.")
801802
.def(
@@ -824,7 +825,8 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
824825
MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
825826
if (mlirAttributeIsNull(a))
826827
return std::nullopt;
827-
return PyAttribute(self.getContext(), a).maybeDownCast();
828+
return nb::cast<nb::typed<nb::object, PyAttribute>>(
829+
PyAttribute(self.getContext(), a).maybeDownCast());
828830
},
829831
"Returns the memory space of the given MemRef type.");
830832
}
@@ -865,7 +867,8 @@ class PyUnrankedMemRefType
865867
MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
866868
if (mlirAttributeIsNull(a))
867869
return std::nullopt;
868-
return PyAttribute(self.getContext(), a).maybeDownCast();
870+
return nb::cast<nb::typed<nb::object, PyAttribute>>(
871+
PyAttribute(self.getContext(), a).maybeDownCast());
869872
},
870873
"Returns the memory space of the given Unranked MemRef type.");
871874
}

mlir/lib/Bindings/Python/NanobindUtils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ class Sliceable {
276276
/// Returns the element at the given slice index. Supports negative indices
277277
/// by taking elements in inverse order. Returns a nullptr object if out
278278
/// of bounds.
279-
nanobind::typed<nanobind::object, ElementTy> getItem(intptr_t index) {
279+
nanobind::object getItem(intptr_t index) {
280280
// Negative indices mean we count from the end.
281281
index = wrapIndex(index);
282282
if (index < 0) {

utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,6 @@ filegroup(
5353
]),
5454
)
5555

56-
filegroup(
57-
name = "IRPyIFiles",
58-
srcs = [
59-
"mlir/_mlir_libs/_mlir/__init__.pyi",
60-
"mlir/_mlir_libs/_mlir/ir.pyi",
61-
],
62-
)
63-
6456
filegroup(
6557
name = "MlirLibsPyFiles",
6658
srcs = [
@@ -75,13 +67,6 @@ filegroup(
7567
],
7668
)
7769

78-
filegroup(
79-
name = "PassManagerPyIFiles",
80-
srcs = [
81-
"mlir/_mlir_libs/_mlir/passmanager.pyi",
82-
],
83-
)
84-
8570
filegroup(
8671
name = "RewritePyFiles",
8772
srcs = [

0 commit comments

Comments
 (0)