1
+ import copy
1
2
import inspect
2
3
import sys
3
4
from dataclasses import dataclass
4
5
from functools import update_wrapper
5
6
from typing import Optional , List , Union , TypeVar
6
7
7
- from ...ast .util import copy_func
8
+ from ...ast .util import copy_func , PyTypeVarObject
8
9
from ...meta import op_region_builder
10
+ from ... import types as T
9
11
from ...util import get_user_code_loc , make_maybe_no_args_decorator
10
12
from ....dialects ._ods_common import get_op_result_or_op_results
11
13
from ....dialects .func import *
@@ -105,17 +107,17 @@ def prep_func_types(sig, return_types):
105
107
return_types = [return_types ]
106
108
return_types = list (return_types )
107
109
assert all (
108
- isinstance (r , Type ) for r in return_types
109
- ), f"all return types must be mlir types { return_types = } "
110
+ isinstance (r , ( str , Type , TypeVar )) or isalambda ( r ) for r in return_types
111
+ ), f"all return types must be mlir types or strings or TypeVars or lambdas { return_types = } "
110
112
111
113
input_types = [
112
114
p .annotation
113
115
for p in sig .parameters .values ()
114
116
if not p .annotation is inspect .Signature .empty
115
117
]
116
118
assert all (
117
- isinstance (r , (str , Type )) or isalambda (r ) for r in input_types
118
- ), f"all input types must be mlir types { input_types = } "
119
+ isinstance (r , (str , Type , TypeVar )) or isalambda (r ) for r in input_types
120
+ ), f"all input types must be mlir types or strings or TypeVars or lambdas { input_types = } "
119
121
user_loc = get_user_code_loc ()
120
122
# If ir.Context is none (like for deferred func emit)
121
123
if user_loc is None :
@@ -205,13 +207,15 @@ def emit(self, *call_args, decl=False, force=False) -> FuncOp:
205
207
if self ._func_op is None or decl or force :
206
208
if len (call_args ) == 0 :
207
209
input_types = self .input_types [:]
208
- locals = {}
210
+ locals = {"T" : T }
209
211
if self .generics is not None :
210
212
for t in self .generics :
211
213
if not isinstance (t , ReifiedTypeParams ):
212
214
raise RuntimeError (f"{ t = } must reified" )
213
215
locals [t .name ] = t .val
214
216
for i , v in enumerate (input_types ):
217
+ if isinstance (v , TypeVar ):
218
+ v = v .__name__
215
219
if isinstance (v , str ):
216
220
input_types [i ] = Type (
217
221
eval (v , self .body_builder .__globals__ , locals )
@@ -274,12 +278,38 @@ def __getitem__(self, item):
274
278
# this also copies the function so that the original body_builder remains "generic" (via its closure)
275
279
body_builder = copy_func (self .body_builder )
276
280
reified_type_params = []
277
- for i , t in enumerate (self .generics ):
278
- if t .__bound__ is not None :
279
- r = ReifiedTypeParams (t .__name__ , t .__bound__ )
281
+ # dumb but whatever
282
+ already_reified_type_params = {}
283
+ generics = copy .deepcopy (self .generics )
284
+ for i , t in enumerate (generics ):
285
+ if sys .version_info >= (3 , 12 ):
286
+ type_var_bound = PyTypeVarObject .try_from (t ).bound
287
+ else :
288
+ type_var_bound = t .__bound__
289
+ if type_var_bound :
290
+ # before 3.12 typevar was just a python class
291
+ # https://github.com/python/cpython/blob/3.11/Lib/typing.py#L966
292
+ if sys .version_info < (3 , 12 ):
293
+ type_var_bound = lambda : type_var_bound
294
+ else :
295
+ type_var_bound = type_var_bound .contents .into_object ()
296
+ cvrs = inspect .getclosurevars (type_var_bound ).nonlocals
297
+ if len (cvrs ):
298
+ for k , v in cvrs .items ():
299
+ if not isinstance (v , TypeVar ):
300
+ continue
301
+ if k not in already_reified_type_params :
302
+ raise RuntimeError (
303
+ f"typevar { k } not reified prior to evaluating dependent typevar { t } "
304
+ )
305
+ cvrs [k ] = already_reified_type_params [k ]
306
+ type_var_bound = copy_func (type_var_bound , cvrs )
307
+ r = ReifiedTypeParams (t .__name__ , type_var_bound ())
280
308
else :
281
309
r = ReifiedTypeParams (t .__name__ , item [i ])
310
+
282
311
reified_type_params .append (r )
312
+ already_reified_type_params [r .name ] = r .val
283
313
284
314
if t .__name__ in body_builder .__globals__ :
285
315
body_builder .__globals__ [t .__name__ ] = r .val
@@ -290,8 +320,6 @@ def __getitem__(self, item):
290
320
), "typevars don't match"
291
321
body_builder .__closure__ [free_i ].cell_contents = r .val
292
322
293
- generics = reified_type_params
294
-
295
323
return FuncBase (
296
324
body_builder ,
297
325
self .func_op_ctor ,
@@ -302,7 +330,7 @@ def __getitem__(self, item):
302
330
arg_attrs = self .arg_attrs ,
303
331
res_attrs = self .res_attrs ,
304
332
func_attrs = self .func_attrs ,
305
- generics = generics ,
333
+ generics = reified_type_params ,
306
334
qualname = self .qualname ,
307
335
loc = self .loc ,
308
336
ip = self .ip ,
0 commit comments