Skip to content

Commit 47d9d73

Browse files
[MLIR][Python] Add arg_attrs and res_attrs to gpu func (#168475)
I missed these attributes when I added the wrapper for GPUFuncOp in fbdd98f.
1 parent 4d09368 commit 47d9d73

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

mlir/python/mlir/dialects/gpu/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,13 @@ class GPUFuncOp(GPUFuncOp):
4949

5050
FUNCTION_TYPE_ATTR_NAME = "function_type"
5151
SYM_NAME_ATTR_NAME = "sym_name"
52-
ARGUMENT_ATTR_NAME = "arg_attrs"
53-
RESULT_ATTR_NAME = "res_attrs"
5452

5553
def __init__(
5654
self,
5755
function_type: Union[FunctionType, TypeAttr],
5856
sym_name: Optional[Union[str, StringAttr]] = None,
57+
arg_attrs: Optional[Sequence[dict]] = None,
58+
res_attrs: Optional[Sequence[dict]] = None,
5959
kernel: Optional[bool] = None,
6060
workgroup_attrib_attrs: Optional[Sequence[dict]] = None,
6161
private_attrib_attrs: Optional[Sequence[dict]] = None,
@@ -88,6 +88,8 @@ def __init__(
8888
)
8989
super().__init__(
9090
function_type,
91+
arg_attrs=arg_attrs,
92+
res_attrs=res_attrs,
9193
workgroup_attrib_attrs=workgroup_attrib_attrs,
9294
private_attrib_attrs=private_attrib_attrs,
9395
loc=loc,

mlir/test/python/dialects/gpu/dialect.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,10 @@ def builder(func: gpu.GPUFuncOp) -> None:
133133
), func.known_grid_size
134134

135135
func = gpu.GPUFuncOp(
136-
func_type,
136+
ir.FunctionType.get(inputs=[T.index()], results=[]),
137137
sym_name="non_kernel_func",
138138
body_builder=builder,
139+
arg_attrs=[{"gpu.some_attribute": ir.StringAttr.get("foo")}],
139140
)
140141
assert not func.is_kernel
141142
assert func.known_block_size is None
@@ -154,10 +155,11 @@ def builder(func: gpu.GPUFuncOp) -> None:
154155
# CHECK: %[[VAL_0:.*]] = gpu.global_id x
155156
# CHECK: gpu.return
156157
# CHECK: }
157-
# CHECK: gpu.func @non_kernel_func() {
158-
# CHECK: %[[VAL_0:.*]] = gpu.global_id x
159-
# CHECK: gpu.return
160-
# CHECK: }
158+
# CHECK: gpu.func @non_kernel_func(
159+
# CHECK-SAME: %[[ARG0:.*]]: index {gpu.some_attribute = "foo"}) {
160+
# CHECK: %[[GLOBAL_ID_0:.*]] = gpu.global_id x
161+
# CHECK: gpu.return
162+
# CHECK: }
161163

162164

163165
# CHECK-LABEL: testGPULaunchFuncOp

0 commit comments

Comments
 (0)