Skip to content

Commit 19ce117

Browse files
Apply suggestions
1 parent 09a604c commit 19ce117

File tree

2 files changed

+64
-60
lines changed

2 files changed

+64
-60
lines changed

mlir/python/mlir/dialects/gpu/__init__.py

Lines changed: 55 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -30,28 +30,27 @@
3030
raise RuntimeError("Error loading imports from extension module") from e
3131

3232

33-
KERNEL_ATTRIBUTE_NAME = "gpu.kernel"
34-
KNOWN_BLOCK_SIZE_ATTRIBUTE_NAME = "gpu.known_block_size"
35-
KNOWN_GRID_SIZE_ATTRIBUTE_NAME = "gpu.known_grid_size"
36-
37-
FUNCTION_TYPE_ATTRIBUTE_NAME = "function_type"
38-
SYM_NAME_ATTRIBUTE_NAME = "sym_name"
39-
ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
40-
RESULT_ATTRIBUTE_NAME = "res_attrs"
41-
42-
4333
@_ods_cext.register_operation(_Dialect, replace=True)
4434
class GPUFuncOp(GPUFuncOp):
4535
__doc__ = GPUFuncOp.__doc__
4636

37+
KERNEL_ATTR_NAME = "gpu.kernel"
38+
KNOWN_BLOCK_SIZE_ATTR_NAME = "known_block_size"
39+
KNOWN_GRID_SIZE_ATTR_NAME = "known_grid_size"
40+
41+
FUNCTION_TYPE_ATTR_NAME = "function_type"
42+
SYM_NAME_ATTR_NAME = "sym_name"
43+
ARGUMENT_ATTR_NAME = "arg_attrs"
44+
RESULT_ATTR_NAME = "res_attrs"
45+
4746
def __init__(
4847
self,
4948
function_type: Union[FunctionType, TypeAttr],
5049
sym_name: Optional[Union[str, StringAttr]] = None,
5150
kernel: Optional[bool] = None,
5251
body_builder: Optional[Callable[[GPUFuncOp], None]] = None,
53-
known_block_size: Optional[Union[List[int], DenseI32ArrayAttr]] = None,
54-
known_grid_size: Optional[Union[List[int], DenseI32ArrayAttr]] = None,
52+
known_block_size: Optional[Union[Sequence[int], DenseI32ArrayAttr]] = None,
53+
known_grid_size: Optional[Union[Sequence[int], DenseI32ArrayAttr]] = None,
5554
*args,
5655
loc=None,
5756
ip=None,
@@ -75,56 +74,52 @@ def __init__(
7574
else function_type
7675
)
7776
super().__init__(function_type, *args, loc=loc, ip=ip, **kwargs)
77+
78+
if isinstance(sym_name, str):
79+
sym_name = StringAttr.get(str(sym_name))
7880
if sym_name is not None:
79-
self.attributes[SYM_NAME_ATTRIBUTE_NAME] = StringAttr.get(str(sym_name))
81+
self.attributes[self.SYM_NAME_ATTR_NAME] = sym_name
82+
8083
if kernel:
81-
self.attributes[KERNEL_ATTRIBUTE_NAME] = UnitAttr.get()
82-
if body_builder is not None:
83-
with InsertionPoint(self.add_entry_block()):
84-
body_builder(self)
84+
self.attributes[self.KERNEL_ATTR_NAME] = UnitAttr.get()
8585
if known_block_size is not None:
86-
if isinstance(known_block_size, list):
87-
self.attributes[KNOWN_BLOCK_SIZE_ATTRIBUTE_NAME] = (
86+
if isinstance(known_block_size, Sequence):
87+
self.attributes[self.KNOWN_BLOCK_SIZE_ATTR_NAME] = (
8888
DenseI32ArrayAttr.get(known_block_size)
8989
)
90+
elif isinstance(known_block_size, DenseI32ArrayAttr):
91+
self.attributes[self.KNOWN_BLOCK_SIZE_ATTR_NAME] = known_block_size
9092
else:
91-
self.attributes[KNOWN_BLOCK_SIZE_ATTRIBUTE_NAME] = known_block_size
93+
raise ValueError(
94+
"known_block_size must be a list of integers or a DenseI32ArrayAttr"
95+
)
96+
9297
if known_grid_size is not None:
93-
if isinstance(known_grid_size, list):
94-
self.attributes[KNOWN_GRID_SIZE_ATTRIBUTE_NAME] = DenseI32ArrayAttr.get(
98+
if isinstance(known_grid_size, Sequence):
99+
self.attributes[self.KNOWN_GRID_SIZE_ATTR_NAME] = DenseI32ArrayAttr.get(
95100
known_grid_size
96101
)
102+
elif isinstance(known_grid_size, DenseI32ArrayAttr):
103+
self.attributes[self.KNOWN_GRID_SIZE_ATTR_NAME] = known_grid_size
97104
else:
98-
self.attributes[KNOWN_GRID_SIZE_ATTRIBUTE_NAME] = known_grid_size
105+
raise ValueError(
106+
"known_grid_size must be a list of integers or a DenseI32ArrayAttr"
107+
)
99108

100-
@property
101-
def type(self) -> FunctionType:
102-
return FunctionType(
103-
TypeAttr(self.attributes[FUNCTION_TYPE_ATTRIBUTE_NAME]).value
104-
)
109+
if body_builder is not None:
110+
with InsertionPoint(self.add_entry_block()):
111+
body_builder(self)
105112

106113
@property
107114
def name(self) -> StringAttr:
108-
return StringAttr(self.attributes[SYM_NAME_ATTRIBUTE_NAME])
115+
return StringAttr(self.attributes[self.SYM_NAME_ATTR_NAME])
109116

110117
@property
111118
def is_kernel(self) -> bool:
112-
return KERNEL_ATTRIBUTE_NAME in self.attributes
113-
114-
@property
115-
def known_block_size(self) -> Optional[DenseI32ArrayAttr]:
116-
if KNOWN_BLOCK_SIZE_ATTRIBUTE_NAME not in self.attributes:
117-
return None
118-
return self.attributes[KNOWN_BLOCK_SIZE_ATTRIBUTE_NAME]
119-
120-
@property
121-
def known_grid_size(self) -> Optional[DenseI32ArrayAttr]:
122-
if KNOWN_GRID_SIZE_ATTRIBUTE_NAME not in self.attributes:
123-
return None
124-
return self.attributes[KNOWN_GRID_SIZE_ATTRIBUTE_NAME]
119+
return self.KERNEL_ATTR_NAME in self.attributes
125120

126121
def add_entry_block(self) -> Block:
127-
function_type = self.type
122+
function_type = self.function_type.value
128123
return self.body.blocks.append(
129124
*function_type.inputs,
130125
arg_locs=[self.location for _ in function_type.inputs],
@@ -136,34 +131,38 @@ def entry_block(self) -> Block:
136131

137132
@property
138133
def arguments(self) -> Sequence[Type]:
139-
return self.type.inputs
134+
return self.function_type.value.inputs
140135

141136
@property
142137
def arg_attrs(self):
143-
if ARGUMENT_ATTRIBUTE_NAME not in self.attributes:
144-
return ArrayAttr.get([DictAttr.get({}) for _ in self.type.inputs])
145-
return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
138+
if self.ARGUMENT_ATTR_NAME not in self.attributes:
139+
return ArrayAttr.get(
140+
[DictAttr.get({}) for _ in self.function_type.value.inputs]
141+
)
142+
return ArrayAttr(self.attributes[self.ARGUMENT_ATTR_NAME])
146143

147144
@arg_attrs.setter
148-
def arg_attrs(self, attribute: Union[ArrayAttr, list[Attribute]]):
145+
def arg_attrs(self, attribute: Union[ArrayAttr, Sequence[Attribute]]):
149146
if isinstance(attribute, ArrayAttr):
150-
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
147+
self.attributes[self.ARGUMENT_ATTR_NAME] = attribute
151148
else:
152-
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
149+
self.attributes[self.ARGUMENT_ATTR_NAME] = ArrayAttr.get(
153150
attribute, context=self.context
154151
)
155152

156153
@property
157154
def result_attrs(self) -> Optional[ArrayAttr]:
158-
if RESULT_ATTRIBUTE_NAME not in self.attributes:
159-
return ArrayAttr.get([DictAttr.get({}) for _ in self.type.results])
160-
return self.attributes[RESULT_ATTRIBUTE_NAME]
155+
if self.RESULT_ATTR_NAME not in self.attributes:
156+
return ArrayAttr.get(
157+
[DictAttr.get({}) for _ in self.function_type.value.results]
158+
)
159+
return self.attributes[self.RESULT_ATTR_NAME]
161160

162161
@result_attrs.setter
163-
def result_attrs(self, attribute: Union[ArrayAttr, list[Attribute]]):
162+
def result_attrs(self, attribute: Union[ArrayAttr, Sequence[Attribute]]):
164163
if isinstance(attribute, ArrayAttr):
165-
self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
164+
self.attributes[self.RESULT_ATTR_NAME] = attribute
166165
else:
167-
self.attributes[RESULT_ATTRIBUTE_NAME] = ArrayAttr.get(
166+
self.attributes[self.RESULT_ATTR_NAME] = ArrayAttr.get(
168167
attribute, context=self.context
169168
)

mlir/test/python/dialects/gpu/dialect.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,8 @@ def builder(func: gpu.GPUFuncOp) -> None:
8686
func_type = ir.FunctionType.get(inputs=[], results=[])
8787
type_attr = TypeAttr.get(func_type)
8888
func = gpu.GPUFuncOp(type_attr, name)
89-
func.attributes[gpu.SYM_NAME_ATTRIBUTE_NAME] = name
90-
func.attributes[gpu.KERNEL_ATTRIBUTE_NAME] = UnitAttr.get()
89+
func.attributes["sym_name"] = name
90+
func.attributes["gpu.kernel"] = UnitAttr.get()
9191
block = func.body.blocks.append()
9292
with InsertionPoint(block):
9393
builder(func)
@@ -102,13 +102,18 @@ def builder(func: gpu.GPUFuncOp) -> None:
102102
)
103103

104104
assert func.name.value == "kernel1"
105+
assert func.function_type.value == func_type
105106
assert func.arg_attrs == ArrayAttr.get([])
106107
assert func.result_attrs == ArrayAttr.get([])
107108
assert func.arguments == []
108109
assert func.entry_block == func.body.blocks[0]
109110
assert func.is_kernel
110-
assert func.known_block_size == DenseI32ArrayAttr.get([1, 2, 3])
111-
assert func.known_grid_size == DenseI32ArrayAttr.get([4, 5, 6])
111+
assert func.known_block_size == DenseI32ArrayAttr.get(
112+
[1, 2, 3]
113+
), func.known_block_size
114+
assert func.known_grid_size == DenseI32ArrayAttr.get(
115+
[4, 5, 6]
116+
), func.known_grid_size
112117

113118
func = gpu.GPUFuncOp(
114119
func_type,

0 commit comments

Comments
 (0)