diff --git a/mlir/python/mlir/dialects/func.py b/mlir/python/mlir/dialects/func.py index 24fdcbcd85b29..1898fc1565cd4 100644 --- a/mlir/python/mlir/dialects/func.py +++ b/mlir/python/mlir/dialects/func.py @@ -105,6 +105,8 @@ def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None): @property def arg_attrs(self): + if ARGUMENT_ATTRIBUTE_NAME not in self.attributes: + return ArrayAttr.get([DictAttr.get({}) for _ in self.type.inputs]) return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME]) @arg_attrs.setter diff --git a/mlir/test/python/dialects/func.py b/mlir/test/python/dialects/func.py index a2014c64d2fa5..6b3932ce64f13 100644 --- a/mlir/test/python/dialects/func.py +++ b/mlir/test/python/dialects/func.py @@ -104,3 +104,32 @@ def testFunctionCalls(): # CHECK: %1 = call @qux() : () -> f32 # CHECK: return # CHECK: } + + +# CHECK-LABEL: TEST: testFunctionArgAttrs +@constructAndPrintInModule +def testFunctionArgAttrs(): + foo = func.FuncOp("foo", ([F32Type.get()], [])) + foo.sym_visibility = StringAttr.get("private") + foo2 = func.FuncOp("foo2", ([F32Type.get(), F32Type.get()], [])) + foo2.sym_visibility = StringAttr.get("private") + + empty_attr = DictAttr.get({}) + test_attr = DictAttr.get({"test.foo": StringAttr.get("bar")}) + test_attr2 = DictAttr.get({"test.baz": StringAttr.get("qux")}) + + assert len(foo.arg_attrs) == 1 + assert foo.arg_attrs[0] == empty_attr + + foo.arg_attrs = [test_attr] + assert foo.arg_attrs[0]["test.foo"] == StringAttr.get("bar") + + assert len(foo2.arg_attrs) == 2 + assert foo2.arg_attrs == ArrayAttr.get([empty_attr, empty_attr]) + + foo2.arg_attrs = [empty_attr, test_attr2] + assert foo2.arg_attrs == ArrayAttr.get([empty_attr, test_attr2]) + + +# CHECK: func private @foo(f32 {test.foo = "bar"}) +# CHECK: func private @foo2(f32, f32 {test.baz = "qux"})