Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 46 additions & 1 deletion mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2730,14 +2730,59 @@ class PyOpAttributeMap {
operation->get(), toMlirStringRef(name)));
}

template <typename F>
auto forEachAttr(F fn) {
intptr_t n = mlirOperationGetNumAttributes(operation->get());
for (intptr_t i = 0; i < n; ++i) {
MlirNamedAttribute na = mlirOperationGetAttribute(operation->get(), i);
MlirStringRef name = mlirIdentifierStr(na.name);
fn(name, na.attribute);
}
}

static void bind(nb::module_ &m) {
nb::class_<PyOpAttributeMap>(m, "OpAttributeMap")
.def("__contains__", &PyOpAttributeMap::dunderContains)
.def("__len__", &PyOpAttributeMap::dunderLen)
.def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
.def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
.def("__setitem__", &PyOpAttributeMap::dunderSetItem)
.def("__delitem__", &PyOpAttributeMap::dunderDelItem);
.def("__delitem__", &PyOpAttributeMap::dunderDelItem)
.def("__iter__",
[](PyOpAttributeMap &self) {
nb::list keys;
self.forEachAttr([&](MlirStringRef name, MlirAttribute) {
keys.append(nb::str(name.data, name.length));
});
return nb::iter(keys);
})
.def("keys",
[](PyOpAttributeMap &self) {
nb::list out;
self.forEachAttr([&](MlirStringRef name, MlirAttribute) {
out.append(nb::str(name.data, name.length));
});
return out;
})
.def("values",
[](PyOpAttributeMap &self) {
nb::list out;
self.forEachAttr([&](MlirStringRef, MlirAttribute attr) {
out.append(PyAttribute(self.operation->getContext(), attr)
.maybeDownCast());
});
return out;
})
.def("items", [](PyOpAttributeMap &self) {
nb::list out;
self.forEachAttr([&](MlirStringRef name, MlirAttribute attr) {
out.append(
nb::make_tuple(nb::str(name.data, name.length),
PyAttribute(self.operation->getContext(), attr)
.maybeDownCast()));
});
return out;
});
}

private:
Expand Down
29 changes: 23 additions & 6 deletions mlir/test/python/ir/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,14 +569,31 @@ def testOperationAttributes():
# CHECK: Attribute value b'text'
print(f"Attribute value {sattr.value_bytes}")

# Python dict-style iteration
# We don't know in which order the attributes are stored.
# CHECK-DAG: NamedAttribute(dependent="text")
# CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64)
# CHECK-DAG: NamedAttribute(some.attribute=1 : i8)
for attr in op.attributes:
print(str(attr))
# CHECK-DAG: dependent
# CHECK-DAG: other.attribute
# CHECK-DAG: some.attribute
for name in op.attributes:
print(name)

# Basic dict-like introspection
# CHECK: True
print("some.attribute" in op.attributes)
# CHECK: False
print("missing" in op.attributes)
# CHECK: Keys: ['dependent', 'other.attribute', 'some.attribute']
print("Keys:", sorted(op.attributes.keys()))
# CHECK: Values count 3
print("Values count", len(op.attributes.values()))
# CHECK: Items count 3
print("Items count", len(op.attributes.items()))

# Dict() conversion test
d = {k: v.value for k, v in dict(op.attributes).items()}
# CHECK: Dict mapping {'dependent': 'text', 'other.attribute': 3.0, 'some.attribute': 1}
print("Dict mapping", d)

# Check that exceptions are raised as expected.
try:
op.attributes["does_not_exist"]
except KeyError:
Expand Down