3030 raise RuntimeError ("Error loading imports from extension module" ) from e
3131
3232
33- KERNEL_ATTRIBUTE_NAME = "gpu.kernel"
34- KNOWN_BLOCK_SIZE_ATTRIBUTE_NAME = "gpu.known_block_size"
35- KNOWN_GRID_SIZE_ATTRIBUTE_NAME = "gpu.known_grid_size"
36-
37- FUNCTION_TYPE_ATTRIBUTE_NAME = "function_type"
38- SYM_NAME_ATTRIBUTE_NAME = "sym_name"
39- ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
40- RESULT_ATTRIBUTE_NAME = "res_attrs"
41-
42-
4333@_ods_cext .register_operation (_Dialect , replace = True )
4434class GPUFuncOp (GPUFuncOp ):
4535 __doc__ = GPUFuncOp .__doc__
4636
37+ KERNEL_ATTR_NAME = "gpu.kernel"
38+ KNOWN_BLOCK_SIZE_ATTR_NAME = "known_block_size"
39+ KNOWN_GRID_SIZE_ATTR_NAME = "known_grid_size"
40+
41+ FUNCTION_TYPE_ATTR_NAME = "function_type"
42+ SYM_NAME_ATTR_NAME = "sym_name"
43+ ARGUMENT_ATTR_NAME = "arg_attrs"
44+ RESULT_ATTR_NAME = "res_attrs"
45+
4746 def __init__ (
4847 self ,
4948 function_type : Union [FunctionType , TypeAttr ],
5049 sym_name : Optional [Union [str , StringAttr ]] = None ,
5150 kernel : Optional [bool ] = None ,
5251 body_builder : Optional [Callable [[GPUFuncOp ], None ]] = None ,
53- known_block_size : Optional [Union [List [int ], DenseI32ArrayAttr ]] = None ,
54- known_grid_size : Optional [Union [List [int ], DenseI32ArrayAttr ]] = None ,
52+ known_block_size : Optional [Union [Sequence [int ], DenseI32ArrayAttr ]] = None ,
53+ known_grid_size : Optional [Union [Sequence [int ], DenseI32ArrayAttr ]] = None ,
5554 * args ,
5655 loc = None ,
5756 ip = None ,
@@ -75,56 +74,52 @@ def __init__(
7574 else function_type
7675 )
7776 super ().__init__ (function_type , * args , loc = loc , ip = ip , ** kwargs )
77+
78+ if isinstance (sym_name , str ):
79+ sym_name = StringAttr .get (str (sym_name ))
7880 if sym_name is not None :
79- self .attributes [SYM_NAME_ATTRIBUTE_NAME ] = StringAttr .get (str (sym_name ))
81+ self .attributes [self .SYM_NAME_ATTR_NAME ] = sym_name
82+
8083 if kernel :
81- self .attributes [KERNEL_ATTRIBUTE_NAME ] = UnitAttr .get ()
82- if body_builder is not None :
83- with InsertionPoint (self .add_entry_block ()):
84- body_builder (self )
84+ self .attributes [self .KERNEL_ATTR_NAME ] = UnitAttr .get ()
8585 if known_block_size is not None :
86- if isinstance (known_block_size , list ):
87- self .attributes [KNOWN_BLOCK_SIZE_ATTRIBUTE_NAME ] = (
86+ if isinstance (known_block_size , Sequence ):
87+ self .attributes [self . KNOWN_BLOCK_SIZE_ATTR_NAME ] = (
8888 DenseI32ArrayAttr .get (known_block_size )
8989 )
90+ elif isinstance (known_block_size , DenseI32ArrayAttr ):
91+ self .attributes [self .KNOWN_BLOCK_SIZE_ATTR_NAME ] = known_block_size
9092 else :
91- self .attributes [KNOWN_BLOCK_SIZE_ATTRIBUTE_NAME ] = known_block_size
93+ raise ValueError (
94+ "known_block_size must be a list of integers or a DenseI32ArrayAttr"
95+ )
96+
9297 if known_grid_size is not None :
93- if isinstance (known_grid_size , list ):
94- self .attributes [KNOWN_GRID_SIZE_ATTRIBUTE_NAME ] = DenseI32ArrayAttr .get (
98+ if isinstance (known_grid_size , Sequence ):
99+ self .attributes [self . KNOWN_GRID_SIZE_ATTR_NAME ] = DenseI32ArrayAttr .get (
95100 known_grid_size
96101 )
102+ elif isinstance (known_grid_size , DenseI32ArrayAttr ):
103+ self .attributes [self .KNOWN_GRID_SIZE_ATTR_NAME ] = known_grid_size
97104 else :
98- self .attributes [KNOWN_GRID_SIZE_ATTRIBUTE_NAME ] = known_grid_size
105+ raise ValueError (
106+ "known_grid_size must be a list of integers or a DenseI32ArrayAttr"
107+ )
99108
100- @property
101- def type (self ) -> FunctionType :
102- return FunctionType (
103- TypeAttr (self .attributes [FUNCTION_TYPE_ATTRIBUTE_NAME ]).value
104- )
109+ if body_builder is not None :
110+ with InsertionPoint (self .add_entry_block ()):
111+ body_builder (self )
105112
106113 @property
107114 def name (self ) -> StringAttr :
108- return StringAttr (self .attributes [SYM_NAME_ATTRIBUTE_NAME ])
115+ return StringAttr (self .attributes [self . SYM_NAME_ATTR_NAME ])
109116
110117 @property
111118 def is_kernel (self ) -> bool :
112- return KERNEL_ATTRIBUTE_NAME in self .attributes
113-
114- @property
115- def known_block_size (self ) -> Optional [DenseI32ArrayAttr ]:
116- if KNOWN_BLOCK_SIZE_ATTRIBUTE_NAME not in self .attributes :
117- return None
118- return self .attributes [KNOWN_BLOCK_SIZE_ATTRIBUTE_NAME ]
119-
120- @property
121- def known_grid_size (self ) -> Optional [DenseI32ArrayAttr ]:
122- if KNOWN_GRID_SIZE_ATTRIBUTE_NAME not in self .attributes :
123- return None
124- return self .attributes [KNOWN_GRID_SIZE_ATTRIBUTE_NAME ]
119+ return self .KERNEL_ATTR_NAME in self .attributes
125120
126121 def add_entry_block (self ) -> Block :
127- function_type = self .type
122+ function_type = self .function_type . value
128123 return self .body .blocks .append (
129124 * function_type .inputs ,
130125 arg_locs = [self .location for _ in function_type .inputs ],
@@ -136,34 +131,38 @@ def entry_block(self) -> Block:
136131
137132 @property
138133 def arguments (self ) -> Sequence [Type ]:
139- return self .type .inputs
134+ return self .function_type . value .inputs
140135
141136 @property
142137 def arg_attrs (self ):
143- if ARGUMENT_ATTRIBUTE_NAME not in self .attributes :
144- return ArrayAttr .get ([DictAttr .get ({}) for _ in self .type .inputs ])
145- return ArrayAttr (self .attributes [ARGUMENT_ATTRIBUTE_NAME ])
138+ if self .ARGUMENT_ATTR_NAME not in self .attributes :
139+ return ArrayAttr .get (
140+ [DictAttr .get ({}) for _ in self .function_type .value .inputs ]
141+ )
142+ return ArrayAttr (self .attributes [self .ARGUMENT_ATTR_NAME ])
146143
147144 @arg_attrs .setter
148- def arg_attrs (self , attribute : Union [ArrayAttr , list [Attribute ]]):
145+ def arg_attrs (self , attribute : Union [ArrayAttr , Sequence [Attribute ]]):
149146 if isinstance (attribute , ArrayAttr ):
150- self .attributes [ARGUMENT_ATTRIBUTE_NAME ] = attribute
147+ self .attributes [self . ARGUMENT_ATTR_NAME ] = attribute
151148 else :
152- self .attributes [ARGUMENT_ATTRIBUTE_NAME ] = ArrayAttr .get (
149+ self .attributes [self . ARGUMENT_ATTR_NAME ] = ArrayAttr .get (
153150 attribute , context = self .context
154151 )
155152
156153 @property
157154 def result_attrs (self ) -> Optional [ArrayAttr ]:
158- if RESULT_ATTRIBUTE_NAME not in self .attributes :
159- return ArrayAttr .get ([DictAttr .get ({}) for _ in self .type .results ])
160- return self .attributes [RESULT_ATTRIBUTE_NAME ]
155+ if self .RESULT_ATTR_NAME not in self .attributes :
156+ return ArrayAttr .get (
157+ [DictAttr .get ({}) for _ in self .function_type .value .results ]
158+ )
159+ return self .attributes [self .RESULT_ATTR_NAME ]
161160
162161 @result_attrs .setter
163- def result_attrs (self , attribute : Union [ArrayAttr , list [Attribute ]]):
162+ def result_attrs (self , attribute : Union [ArrayAttr , Sequence [Attribute ]]):
164163 if isinstance (attribute , ArrayAttr ):
165- self .attributes [RESULT_ATTRIBUTE_NAME ] = attribute
164+ self .attributes [self . RESULT_ATTR_NAME ] = attribute
166165 else :
167- self .attributes [RESULT_ATTRIBUTE_NAME ] = ArrayAttr .get (
166+ self .attributes [self . RESULT_ATTR_NAME ] = ArrayAttr .get (
168167 attribute , context = self .context
169168 )
0 commit comments