Skip to content

Commit 23610b5

Browse files
committed
port mlir_attribute_subclass
1 parent 97f20ab commit 23610b5

File tree

2 files changed

+24
-16
lines changed

2 files changed

+24
-16
lines changed

mlir/test/python/dialects/python_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -586,9 +586,9 @@ def testCustomAttribute():
586586
try:
587587
TestAttr(42)
588588
except TypeError as e:
589-
assert "Expected an MLIR object (got 42)" in str(e)
590-
except ValueError as e:
591-
assert "Cannot cast attribute to TestAttr (from 42)" in str(e)
589+
assert "__init__(): incompatible function arguments. The following argument types are supported" in str(e)
590+
assert "__init__(self, cast_from_attr: mlir._mlir_libs._mlir.ir.Attribute) -> None" in str(e)
591+
assert "Invoked with types: mlir._mlir_libs._mlirPythonTestNanobind.TestAttr, int" in str(e)
592592
else:
593593
raise
594594

mlir/test/python/lib/PythonTestModuleNanobind.cpp

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,26 @@ struct PyTestType : mlir::python::PyConcreteType<PyTestType> {
4545
}
4646
};
4747

48+
class PyTestAttr : public mlir::python::PyConcreteAttribute<PyTestAttr> {
49+
public:
50+
static constexpr IsAFunctionTy isaFunction =
51+
mlirAttributeIsAPythonTestTestAttribute;
52+
static constexpr const char *pyClassName = "TestAttr";
53+
using PyConcreteAttribute::PyConcreteAttribute;
54+
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
55+
mlirPythonTestTestAttributeGetTypeID;
56+
57+
static void bindDerived(ClassTy &c) {
58+
c.def_static(
59+
"get",
60+
[](mlir::python::DefaultingPyMlirContext context) {
61+
return PyTestAttr(context->getRef(), mlirPythonTestTestAttributeGet(
62+
context.get()->get()));
63+
},
64+
nb::arg("context").none() = nb::none());
65+
}
66+
};
67+
4868
NB_MODULE(_mlirPythonTestNanobind, m) {
4969
m.def(
5070
"register_python_test_dialect",
@@ -84,19 +104,7 @@ NB_MODULE(_mlirPythonTestNanobind, m) {
84104
nb::sig("def test_diagnostics_with_errors_and_notes(arg: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") ", /) -> None"));
85105
// clang-format on
86106

87-
mlir_attribute_subclass(m, "TestAttr",
88-
mlirAttributeIsAPythonTestTestAttribute,
89-
mlirPythonTestTestAttributeGetTypeID)
90-
.def_classmethod(
91-
"get",
92-
[](const nb::object &cls, MlirContext ctx) {
93-
return cls(mlirPythonTestTestAttributeGet(ctx));
94-
},
95-
// clang-format off
96-
nb::sig("def get(cls: object, context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") " | None = None) -> object"),
97-
// clang-format on
98-
nb::arg("cls"), nb::arg("context").none() = nb::none());
99-
107+
PyTestAttr::bind(m);
100108
PyTestType::bind(m);
101109

102110
auto typeCls =

0 commit comments

Comments
 (0)