1- import sys
21from functools import partial
32from typing import Union
43
2524 VectorType ,
2625)
2726
28- _index_t = lambda : IndexType .get ()
29- _bool_t = lambda : IntegerType .get_signless (1 )
27+ _index = lambda : IndexType .get ()
28+ _bool = lambda : IntegerType .get_signless (1 )
3029
31- _i8_t = lambda : IntegerType .get_signless (8 )
32- _i16_t = lambda : IntegerType .get_signless (16 )
33- _i32_t = lambda : IntegerType .get_signless (32 )
34- _i64_t = lambda : IntegerType .get_signless (64 )
30+ _i8 = lambda : IntegerType .get_signless (8 )
31+ _i16 = lambda : IntegerType .get_signless (16 )
32+ _i32 = lambda : IntegerType .get_signless (32 )
33+ _i64 = lambda : IntegerType .get_signless (64 )
3534
36- _si8_t = lambda : IntegerType .get_signed (8 )
37- _si16_t = lambda : IntegerType .get_signed (16 )
38- _si32_t = lambda : IntegerType .get_signed (32 )
39- _si64_t = lambda : IntegerType .get_signed (64 )
35+ _si8 = lambda : IntegerType .get_signed (8 )
36+ _si16 = lambda : IntegerType .get_signed (16 )
37+ _si32 = lambda : IntegerType .get_signed (32 )
38+ _si64 = lambda : IntegerType .get_signed (64 )
4039
41- _ui8_t = lambda : IntegerType .get_unsigned (8 )
42- _ui16_t = lambda : IntegerType .get_unsigned (16 )
43- _ui32_t = lambda : IntegerType .get_unsigned (32 )
44- _ui64_t = lambda : IntegerType .get_unsigned (64 )
40+ _ui8 = lambda : IntegerType .get_unsigned (8 )
41+ _ui16 = lambda : IntegerType .get_unsigned (16 )
42+ _ui32 = lambda : IntegerType .get_unsigned (32 )
43+ _ui64 = lambda : IntegerType .get_unsigned (64 )
4544
46- _f16_t = lambda : F16Type .get ()
47- _f32_t = lambda : F32Type .get ()
48- _f64_t = lambda : F64Type .get ()
49- _bf16_t = lambda : BF16Type .get ()
45+ _f16 = lambda : F16Type .get ()
46+ _f32 = lambda : F32Type .get ()
47+ _f64 = lambda : F64Type .get ()
48+ _bf16 = lambda : BF16Type .get ()
5049
51- _f8e5m2_t = lambda : Float8E5M2Type .get ()
52- _f8e4m3_t = lambda : Float8E4M3FNType .get ()
53- _f8e4m3b11fnuz_t = lambda : Float8E4M3B11FNUZType .get ()
50+ _f8e5m2 = lambda : Float8E5M2Type .get ()
51+ _f8e4m3 = lambda : Float8E4M3FNType .get ()
52+ _f8e4m3b11fnuz = lambda : Float8E4M3B11FNUZType .get ()
5453
55- _cmp16_t = lambda : ComplexType .get (_f16_t ())
56- _cmp32_t = lambda : ComplexType .get (_f32_t ())
57- _cmp64_t = lambda : ComplexType .get (_f64_t ())
54+ _cmp16 = lambda : ComplexType .get (_f16 ())
55+ _cmp32 = lambda : ComplexType .get (_f32 ())
56+ _cmp64 = lambda : ComplexType .get (_f64 ())
5857
59- _none_t = lambda : NoneType .get ()
58+ _none = lambda : NoneType .get ()
6059
61- opaque_t = lambda dialect_namespace , buffer : OpaqueType .get (dialect_namespace , buffer )
60+ opaque = lambda dialect_namespace , buffer : OpaqueType .get (dialect_namespace , buffer )
6261
6362
64- def _placeholder_opaque_t ():
65- return opaque_t ("scf" , "placeholder" )
63+ def placeholder_opaque ():
64+ return opaque ("scf" , "placeholder" )
6665
6766
6867_name_to_type = {
69- "index_t " : _index_t ,
70- "bool_t " : _bool_t ,
71- "i8_t " : _i8_t ,
72- "i16_t " : _i16_t ,
73- "i32_t " : _i32_t ,
74- "i64_t " : _i64_t ,
75- "si8_t " : _si8_t ,
76- "si16_t " : _si16_t ,
77- "si32_t " : _si32_t ,
78- "si64_t " : _si64_t ,
79- "ui8_t " : _ui8_t ,
80- "ui16_t " : _ui16_t ,
81- "ui32_t " : _ui32_t ,
82- "ui64_t " : _ui64_t ,
83- "f16_t " : _f16_t ,
84- "f32_t " : _f32_t ,
85- "f64_t " : _f64_t ,
86- "bf16_t " : _bf16_t ,
87- "f8e5m2_t " : _f8e5m2_t ,
88- "f8e4m3_t " : _f8e4m3_t ,
89- "f8e4m3b11fnuz_t " : _f8e4m3b11fnuz_t ,
90- "cmp16_t " : _cmp16_t ,
91- "cmp32_t " : _cmp32_t ,
92- "cmp64_t " : _cmp64_t ,
93- "none_t " : _none_t ,
68+ "index " : _index ,
69+ "bool " : _bool ,
70+ "i8 " : _i8 ,
71+ "i16 " : _i16 ,
72+ "i32 " : _i32 ,
73+ "i64 " : _i64 ,
74+ "si8 " : _si8 ,
75+ "si16 " : _si16 ,
76+ "si32 " : _si32 ,
77+ "si64 " : _si64 ,
78+ "ui8 " : _ui8 ,
79+ "ui16 " : _ui16 ,
80+ "ui32 " : _ui32 ,
81+ "ui64 " : _ui64 ,
82+ "f16 " : _f16 ,
83+ "f32 " : _f32 ,
84+ "f64 " : _f64 ,
85+ "bf16 " : _bf16 ,
86+ "f8e5m2 " : _f8e5m2 ,
87+ "f8e4m3 " : _f8e4m3 ,
88+ "f8e4m3b11fnuz " : _f8e4m3b11fnuz ,
89+ "cmp16 " : _cmp16 ,
90+ "cmp32 " : _cmp32 ,
91+ "cmp64 " : _cmp64 ,
92+ "none " : _none ,
9493}
9594
9695
@@ -102,19 +101,19 @@ def __getattr__(name):
102101
103102
104103_np_dtype_to_mlir_type_ctor = {
105- np .int8 : _i8_t ,
106- np .int16 : _i16_t ,
107- np .int32 : _i32_t ,
104+ np .int8 : _i8 ,
105+ np .int16 : _i16 ,
106+ np .int32 : _i32 ,
108107 # windows
109- np .intc : _i32_t ,
110- np .int64 : _i64_t ,
108+ np .intc : _i32 ,
109+ np .int64 : _i64 ,
111110 # is technically wrong i guess but numpy by default casts python scalars to this
112111 # so to support passing lists of ints we map to index type
113- np .longlong : _index_t ,
114- np .uintp : _index_t ,
115- np .float16 : _f16_t ,
116- np .float32 : _f32_t ,
117- np .float64 : _f64_t ,
112+ np .longlong : _index ,
113+ np .uintp : _index ,
114+ np .float16 : _f16 ,
115+ np .float32 : _f32 ,
116+ np .float64 : _f64 ,
118117}
119118
120119_mlir_type_ctor_to_np_dtype = lambda : {
@@ -146,16 +145,16 @@ def infer_mlir_type(
146145 MLIR type corresponding to py_val.
147146 """
148147 if isinstance (py_val , bool ):
149- return _bool_t ()
148+ return _bool ()
150149 elif isinstance (py_val , int ):
151150 if - (2 ** 31 ) <= py_val < 2 ** 31 :
152- return _i32_t ()
151+ return _i32 ()
153152 elif 2 ** 31 <= py_val < 2 ** 32 :
154- return _ui32_t ()
153+ return _ui32 ()
155154 elif - (2 ** 63 ) <= py_val < 2 ** 63 :
156- return _i64_t ()
155+ return _i64 ()
157156 elif 2 ** 63 <= py_val < 2 ** 64 :
158- return _ui64_t ()
157+ return _ui64 ()
159158 else :
160159 raise RuntimeError (f"Nonrepresentable integer { py_val } ." )
161160 elif isinstance (py_val , float ):
@@ -165,9 +164,9 @@ def infer_mlir_type(
165164 or py_val != py_val # NaN
166165 or np .finfo (np .float32 ).min <= abs (py_val ) <= np .finfo (np .float32 ).max
167166 ):
168- return _f32_t ()
167+ return _f32 ()
169168 else :
170- return _f64_t ()
169+ return _f64 ()
171170 elif isinstance (py_val , np .ndarray ):
172171 dtype = np_dtype_to_mlir_type (py_val .dtype .type )
173172 return RankedTensorType .get (py_val .shape , dtype )
@@ -177,9 +176,9 @@ def infer_mlir_type(
177176 )
178177
179178
180- def shaped_t (* args , element_type : Type = None , type_constructor = None ):
179+ def shaped (* args , element_type : Type = None , type_constructor = None ):
181180 if type_constructor is None :
182- raise ValueError ("shaped_t is an abstract base class - cannot be constructed" )
181+ raise ValueError ("shaped is an abstract base class - cannot be constructed" )
183182 if (element_type is None and args and not isinstance (args [- 1 ], Type )) or (
184183 args and isinstance (args [- 1 ], Type ) and element_type is not None
185184 ):
@@ -198,33 +197,33 @@ def shaped_t(*args, element_type: Type = None, type_constructor=None):
198197 return type_constructor (type )
199198
200199
201- def vector_t (* args , element_type : Type = None ):
202- return shaped_t (* args , element_type = element_type , type_constructor = VectorType .get )
200+ def vector (* args , element_type : Type = None ):
201+ return shaped (* args , element_type = element_type , type_constructor = VectorType .get )
203202
204203
205- def tensor_t (* args , element_type : Type = None ):
204+ def tensor (* args , element_type : Type = None ):
206205 if not len (args ) or len (args ) == 1 and isinstance (args [- 1 ], Type ):
207- return shaped_t (
206+ return shaped (
208207 * args , element_type = element_type , type_constructor = UnrankedTensorType .get
209208 )
210209 else :
211- return shaped_t (
210+ return shaped (
212211 * args , element_type = element_type , type_constructor = RankedTensorType .get
213212 )
214213
215214
216- def memref_t (* args , element_type : Type = None , memory_space : int = None ):
215+ def memref (* args , element_type : Type = None , memory_space : int = None ):
217216 if memory_space is None :
218217 memory_space = 0
219218 memory_space = Attribute .parse (str (memory_space ))
220219 if not len (args ) or len (args ) == 1 and isinstance (args [- 1 ], Type ):
221- return shaped_t (
220+ return shaped (
222221 * args ,
223222 element_type = element_type ,
224223 type_constructor = partial (UnrankedMemRefType .get , memory_space = memory_space ),
225224 )
226225 else :
227- return shaped_t (
226+ return shaped (
228227 * args ,
229228 element_type = element_type ,
230229 type_constructor = partial (MemRefType .get , memory_space = memory_space ),
0 commit comments