Skip to content

Commit 35cd291

Browse files
authored
[mlir][python] add dict-style to IR attributes (llvm#163200)
It makes sense that Attribute dicts/maps should behave like dicts in the Python bindings. Previously this was not the case.
1 parent 043cdf0 commit 35cd291

File tree

2 files changed

+78
-6
lines changed

2 files changed

+78
-6
lines changed

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2730,14 +2730,68 @@ class PyOpAttributeMap {
27302730
operation->get(), toMlirStringRef(name)));
27312731
}
27322732

2733+
static void
2734+
forEachAttr(MlirOperation op,
2735+
llvm::function_ref<void(MlirStringRef, MlirAttribute)> fn) {
2736+
intptr_t n = mlirOperationGetNumAttributes(op);
2737+
for (intptr_t i = 0; i < n; ++i) {
2738+
MlirNamedAttribute na = mlirOperationGetAttribute(op, i);
2739+
MlirStringRef name = mlirIdentifierStr(na.name);
2740+
fn(name, na.attribute);
2741+
}
2742+
}
2743+
27332744
static void bind(nb::module_ &m) {
27342745
nb::class_<PyOpAttributeMap>(m, "OpAttributeMap")
27352746
.def("__contains__", &PyOpAttributeMap::dunderContains)
27362747
.def("__len__", &PyOpAttributeMap::dunderLen)
27372748
.def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
27382749
.def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
27392750
.def("__setitem__", &PyOpAttributeMap::dunderSetItem)
2740-
.def("__delitem__", &PyOpAttributeMap::dunderDelItem);
2751+
.def("__delitem__", &PyOpAttributeMap::dunderDelItem)
2752+
.def("__iter__",
2753+
[](PyOpAttributeMap &self) {
2754+
nb::list keys;
2755+
PyOpAttributeMap::forEachAttr(
2756+
self.operation->get(),
2757+
[&](MlirStringRef name, MlirAttribute) {
2758+
keys.append(nb::str(name.data, name.length));
2759+
});
2760+
return nb::iter(keys);
2761+
})
2762+
.def("keys",
2763+
[](PyOpAttributeMap &self) {
2764+
nb::list out;
2765+
PyOpAttributeMap::forEachAttr(
2766+
self.operation->get(),
2767+
[&](MlirStringRef name, MlirAttribute) {
2768+
out.append(nb::str(name.data, name.length));
2769+
});
2770+
return out;
2771+
})
2772+
.def("values",
2773+
[](PyOpAttributeMap &self) {
2774+
nb::list out;
2775+
PyOpAttributeMap::forEachAttr(
2776+
self.operation->get(),
2777+
[&](MlirStringRef, MlirAttribute attr) {
2778+
out.append(PyAttribute(self.operation->getContext(), attr)
2779+
.maybeDownCast());
2780+
});
2781+
return out;
2782+
})
2783+
.def("items", [](PyOpAttributeMap &self) {
2784+
nb::list out;
2785+
PyOpAttributeMap::forEachAttr(
2786+
self.operation->get(),
2787+
[&](MlirStringRef name, MlirAttribute attr) {
2788+
out.append(nb::make_tuple(
2789+
nb::str(name.data, name.length),
2790+
PyAttribute(self.operation->getContext(), attr)
2791+
.maybeDownCast()));
2792+
});
2793+
return out;
2794+
});
27412795
}
27422796

27432797
private:

mlir/test/python/ir/operation.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -569,12 +569,30 @@ def testOperationAttributes():
569569
# CHECK: Attribute value b'text'
570570
print(f"Attribute value {sattr.value_bytes}")
571571

572+
# Python dict-style iteration
572573
# We don't know in which order the attributes are stored.
573-
# CHECK-DAG: NamedAttribute(dependent="text")
574-
# CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64)
575-
# CHECK-DAG: NamedAttribute(some.attribute=1 : i8)
576-
for attr in op.attributes:
577-
print(str(attr))
574+
# CHECK-DAG: dependent
575+
# CHECK-DAG: other.attribute
576+
# CHECK-DAG: some.attribute
577+
for name in op.attributes:
578+
print(name)
579+
580+
# Basic dict-like introspection
581+
# CHECK: True
582+
print("some.attribute" in op.attributes)
583+
# CHECK: False
584+
print("missing" in op.attributes)
585+
# CHECK: Keys: ['dependent', 'other.attribute', 'some.attribute']
586+
print("Keys:", sorted(op.attributes.keys()))
587+
# CHECK: Values count 3
588+
print("Values count", len(op.attributes.values()))
589+
# CHECK: Items count 3
590+
print("Items count", len(op.attributes.items()))
591+
592+
# Dict() conversion test
593+
d = {k: v.value for k, v in dict(op.attributes).items()}
594+
# CHECK: Dict mapping {'dependent': 'text', 'other.attribute': 3.0, 'some.attribute': 1}
595+
print("Dict mapping", d)
578596

579597
# Check that exceptions are raised as expected.
580598
try:

0 commit comments

Comments
 (0)