From bad0d96c1f3be74b9fb7f9da9c7cd05c58b62375 Mon Sep 17 00:00:00 2001 From: pez Date: Thu, 21 Nov 2024 16:40:04 +0000 Subject: [PATCH 1/4] Add check to see if `FuncOp.arg_attrs` is set --- mlir/python/mlir/dialects/func.py | 4 ++++ mlir/test/python/dialects/func.py | 13 +++++++++++++ 2 files changed, 17 insertions(+) diff --git a/mlir/python/mlir/dialects/func.py b/mlir/python/mlir/dialects/func.py index 24fdcbcd85b29..211027d88051a 100644 --- a/mlir/python/mlir/dialects/func.py +++ b/mlir/python/mlir/dialects/func.py @@ -105,6 +105,10 @@ def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None): @property def arg_attrs(self): + if ARGUMENT_ATTRIBUTE_NAME not in self.attributes: + self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get( + [DictAttr.get({}) for _ in self.type.inputs] + ) return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME]) @arg_attrs.setter diff --git a/mlir/test/python/dialects/func.py b/mlir/test/python/dialects/func.py index a2014c64d2fa5..bcfaace853bc6 100644 --- a/mlir/test/python/dialects/func.py +++ b/mlir/test/python/dialects/func.py @@ -104,3 +104,16 @@ def testFunctionCalls(): # CHECK: %1 = call @qux() : () -> f32 # CHECK: return # CHECK: } + + +# CHECK-LABEL: TEST: testFunctionArgAttrs +@constructAndPrintInModule +def testFunctionArgAttrs(): + foo = func.FuncOp("foo", ([("arg0", F32Type.get())], [])) + + assert len(foo.arg_attrs) == 1 + assert foo.arg_attrs[0] = ir.DictAttr.get({}) + + foo.arg_attrs = [DictAttr.get({"test.foo": StringAttr.get("bar")})] + + assert foo.arg_attrs[0]["test.foo"] == StringAttr.get("bar") From 049bacebd0c9dbe38e23568fe16113731093a81f Mon Sep 17 00:00:00 2001 From: pez Date: Thu, 21 Nov 2024 16:56:32 +0000 Subject: [PATCH 2/4] Linting fix --- mlir/test/python/dialects/func.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/test/python/dialects/func.py b/mlir/test/python/dialects/func.py index bcfaace853bc6..cc2d616c6407c 100644 --- a/mlir/test/python/dialects/func.py +++ b/mlir/test/python/dialects/func.py @@ -112,7 +112,7 @@ def testFunctionArgAttrs(): foo = func.FuncOp("foo", ([("arg0", F32Type.get())], [])) assert len(foo.arg_attrs) == 1 - assert foo.arg_attrs[0] = ir.DictAttr.get({}) + assert foo.arg_attrs[0] == DictAttr.get({}) foo.arg_attrs = [DictAttr.get({"test.foo": StringAttr.get("bar")})] From 2ae4815fe814948ad012aba1e3a7c9dd12e05909 Mon Sep 17 00:00:00 2001 From: pez Date: Fri, 22 Nov 2024 14:33:45 +0000 Subject: [PATCH 3/4] Add additional test case --- mlir/test/python/dialects/func.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/mlir/test/python/dialects/func.py b/mlir/test/python/dialects/func.py index cc2d616c6407c..6b3932ce64f13 100644 --- a/mlir/test/python/dialects/func.py +++ b/mlir/test/python/dialects/func.py @@ -109,11 +109,27 @@ def testFunctionCalls(): # CHECK-LABEL: TEST: testFunctionArgAttrs @constructAndPrintInModule def testFunctionArgAttrs(): - foo = func.FuncOp("foo", ([("arg0", F32Type.get())], [])) + foo = func.FuncOp("foo", ([F32Type.get()], [])) + foo.sym_visibility = StringAttr.get("private") + foo2 = func.FuncOp("foo2", ([F32Type.get(), F32Type.get()], [])) + foo2.sym_visibility = StringAttr.get("private") - assert len(foo.arg_attrs) == 1 - assert foo.arg_attrs[0] == DictAttr.get({}) + empty_attr = DictAttr.get({}) + test_attr = DictAttr.get({"test.foo": StringAttr.get("bar")}) + test_attr2 = DictAttr.get({"test.baz": StringAttr.get("qux")}) - foo.arg_attrs = [DictAttr.get({"test.foo": StringAttr.get("bar")})] + assert len(foo.arg_attrs) == 1 + assert foo.arg_attrs[0] == empty_attr + foo.arg_attrs = [test_attr] assert foo.arg_attrs[0]["test.foo"] == StringAttr.get("bar") + + assert len(foo2.arg_attrs) == 2 + assert foo2.arg_attrs == ArrayAttr.get([empty_attr, empty_attr]) + + foo2.arg_attrs = [empty_attr, test_attr2] + assert foo2.arg_attrs == ArrayAttr.get([empty_attr, test_attr2]) + + +# CHECK: func private @foo(f32 {test.foo = "bar"}) +# CHECK: func private @foo2(f32, f32 {test.baz = "qux"}) From c6afea5b86ccb35b10afdf5c03859378749f73c8 Mon Sep 17 00:00:00 2001 From: Perry Gibson Date: Tue, 26 Nov 2024 13:53:54 +0000 Subject: [PATCH 4/4] Return empty dict list, do not mutate IR --- mlir/python/mlir/dialects/func.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mlir/python/mlir/dialects/func.py b/mlir/python/mlir/dialects/func.py index 211027d88051a..1898fc1565cd4 100644 --- a/mlir/python/mlir/dialects/func.py +++ b/mlir/python/mlir/dialects/func.py @@ -106,9 +106,7 @@ def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None): @property def arg_attrs(self): if ARGUMENT_ATTRIBUTE_NAME not in self.attributes: - self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get( - [DictAttr.get({}) for _ in self.type.inputs] - ) + return ArrayAttr.get([DictAttr.get({}) for _ in self.type.inputs]) return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME]) @arg_attrs.setter