@@ -144,21 +144,26 @@ def __init__(
144
144
* ,
145
145
return_types = None ,
146
146
sym_visibility = None ,
147
+ sym_name = None ,
147
148
arg_attrs = None ,
148
149
res_attrs = None ,
149
150
func_attrs = None ,
151
+ function_type = None ,
150
152
generics : List [Union [TypeVar , ReifiedTypeParams ]] = None ,
151
153
qualname = None ,
152
154
loc = None ,
153
155
ip = None ,
154
156
):
155
157
assert inspect .isfunction (body_builder ), body_builder
156
158
assert inspect .isclass (func_op_ctor ), func_op_ctor
157
- assert inspect .isclass (return_op_ctor ), return_op_ctor
159
+ if return_op_ctor is not None :
160
+ assert inspect .isclass (return_op_ctor ), return_op_ctor
158
161
assert inspect .isclass (call_op_ctor ), call_op_ctor
159
162
160
163
self .body_builder = body_builder
161
- self .func_name = self .body_builder .__name__
164
+ if sym_name is None :
165
+ sym_name = self .body_builder .__name__
166
+ self .func_name = sym_name
162
167
self .func_op_ctor = func_op_ctor
163
168
self .return_op_ctor = return_op_ctor
164
169
self .call_op_ctor = call_op_ctor
@@ -175,6 +180,7 @@ def __init__(
175
180
self .func_attrs = func_attrs
176
181
if self .func_attrs is None :
177
182
self .func_attrs = {}
183
+ self .function_type = function_type
178
184
179
185
if return_types is None :
180
186
return_types = []
@@ -208,32 +214,37 @@ def __str__(self):
208
214
209
215
def emit (self , * call_args , decl = False , force = False ) -> FuncOp :
210
216
if self ._func_op is None or decl or force :
211
- if len (call_args ) == 0 :
212
- input_types = self .input_types [:]
213
- locals = {"T" : T }
214
- if self .generics is not None :
215
- for t in self .generics :
216
- if not isinstance (t , ReifiedTypeParams ):
217
- raise RuntimeError (f"{ t = } must reified" )
218
- locals [t .name ] = t .val
219
- for i , v in enumerate (input_types ):
220
- if isinstance (v , TypeVar ):
221
- v = v .__name__
222
- if isinstance (v , str ):
223
- input_types [i ] = Type (
224
- eval (v , self .body_builder .__globals__ , locals )
225
- )
226
- elif isalambda (v ):
227
- input_types [i ] = v ()
228
- else :
229
- input_types = [a .type for a in call_args ]
217
+ if self .function_type is None :
218
+ if len (call_args ) == 0 :
219
+ input_types = self .input_types [:]
220
+ locals = {"T" : T }
221
+ if self .generics is not None :
222
+ for t in self .generics :
223
+ if not isinstance (t , ReifiedTypeParams ):
224
+ raise RuntimeError (f"{ t = } must reified" )
225
+ locals [t .name ] = t .val
226
+ for i , v in enumerate (input_types ):
227
+ if isinstance (v , TypeVar ):
228
+ v = v .__name__
229
+ if isinstance (v , str ):
230
+ input_types [i ] = Type (
231
+ eval (v , self .body_builder .__globals__ , locals )
232
+ )
233
+ elif isalambda (v ):
234
+ input_types [i ] = v ()
235
+ else :
236
+ input_types = [a .type for a in call_args ]
230
237
231
- function_type = TypeAttr .get (
232
- FunctionType .get (
233
- inputs = input_types ,
234
- results = self .return_types ,
238
+ function_type = TypeAttr .get (
239
+ FunctionType .get (
240
+ inputs = input_types ,
241
+ results = self .return_types ,
242
+ )
235
243
)
236
- )
244
+ else :
245
+ input_types = self .function_type .inputs
246
+ function_type = TypeAttr .get (self .function_type )
247
+
237
248
self ._func_op = self .func_op_ctor (
238
249
self .func_name ,
239
250
function_type ,
@@ -264,10 +275,15 @@ def grab_results(*args):
264
275
return_types .append (results .type )
265
276
return results
266
277
267
- builder_wrapper (grab_results )
278
+ if self .function_type is None :
279
+ builder_wrapper (grab_results )
280
+ function_type = FunctionType .get (
281
+ inputs = input_types , results = return_types
282
+ )
283
+ self ._func_op .attributes ["function_type" ] = TypeAttr .get (function_type )
284
+ else :
285
+ builder_wrapper (self .body_builder )
268
286
269
- function_type = FunctionType .get (inputs = input_types , results = return_types )
270
- self ._func_op .attributes ["function_type" ] = TypeAttr .get (function_type )
271
287
return self ._func_op
272
288
273
289
def __call__ (self , * call_args ):
@@ -345,9 +361,11 @@ def func(
345
361
f ,
346
362
* ,
347
363
sym_visibility = None ,
364
+ sym_name = None ,
348
365
arg_attrs = None ,
349
366
res_attrs = None ,
350
367
func_attrs = None ,
368
+ function_type = None ,
351
369
emit = False ,
352
370
generics = None ,
353
371
loc = None ,
@@ -363,9 +381,11 @@ def func(
363
381
return_op_ctor = ReturnOp ,
364
382
call_op_ctor = CallOp .__base__ ,
365
383
sym_visibility = sym_visibility ,
384
+ sym_name = sym_name ,
366
385
arg_attrs = arg_attrs ,
367
386
res_attrs = res_attrs ,
368
387
func_attrs = func_attrs ,
388
+ function_type = function_type ,
369
389
generics = generics ,
370
390
loc = loc ,
371
391
ip = ip ,
0 commit comments