88 IndexType ,
99 F16Type ,
1010 F32Type ,
11+ Type ,
1112)
1213
13- index = IndexType .get ()
14- bool_ = IntegerType .get_signless (1 )
15- i8 = IntegerType .get_signless (8 )
16- i16 = IntegerType .get_signless (16 )
17- i32 = IntegerType .get_signless (32 )
18- i64 = IntegerType .get_signless (64 )
19- f16 = F16Type .get ()
20- f32 = F32Type .get ()
21- f64 = F64Type .get ()
14+ index_t = IndexType .get ()
15+ bool_t = IntegerType .get_signless (1 )
16+ i8_t = IntegerType .get_signless (8 )
17+ i16_t = IntegerType .get_signless (16 )
18+ i32_t = IntegerType .get_signless (32 )
19+ i64_t = IntegerType .get_signless (64 )
20+ f16_t = F16Type .get ()
21+ f32_t = F32Type .get ()
22+ f64_t = F64Type .get ()
2223
2324NP_DTYPE_TO_MLIR_TYPE = lambda : {
24- np .int8 : i8 ,
25- np .int16 : i16 ,
26- np .int32 : i32 ,
27- np .int64 : i64 ,
25+ np .int8 : i8_t ,
26+ np .int16 : i16_t ,
27+ np .int32 : i32_t ,
28+ np .int64 : i64_t ,
2829 # this is techincally wrong i guess but numpy by default casts python scalars to this
2930 # so to support passing lists of ints we map this to index type
30- np .longlong : index ,
31- np .uintp : index ,
32- np .float16 : f16 ,
33- np .float32 : f32 ,
34- np .float64 : f64 ,
31+ np .longlong : index_t ,
32+ np .uintp : index_t ,
33+ np .float16 : f16_t ,
34+ np .float32 : f32_t ,
35+ np .float64 : f64_t ,
3536}
3637
3738MLIR_TYPE_TO_NP_DTYPE = lambda : {v : k for k , v in NP_DTYPE_TO_MLIR_TYPE ().items ()}
@@ -51,15 +52,29 @@ def infer_mlir_type(
5152 MLIR type corresponding to py_val.
5253 """
5354 if isinstance (py_val , bool ):
54- return bool_
55+ return bool_t
5556 elif isinstance (py_val , int ):
56- return i64
57+ return i64_t
5758 elif isinstance (py_val , float ):
58- return f64
59+ return f64_t
5960 elif isinstance (py_val , np .ndarray ):
6061 dtype = NP_DTYPE_TO_MLIR_TYPE ()[py_val .dtype .type ]
6162 return RankedTensorType .get (py_val .shape , dtype )
6263 else :
6364 raise NotImplementedError (
6465 f"Unsupported Python value { py_val = } with type { type (py_val )} "
6566 )
67+
68+
69+ def tensor_t (* args , element_type : Type = None ):
70+ if (element_type is None and not isinstance (args [- 1 ], Type )) or (
71+ isinstance (args [- 1 ], Type ) and element_type is not None
72+ ):
73+ raise ValueError (
74+ f"either element_type must be provided explicitly XOR last arg to tensor type constructor must be the element type"
75+ )
76+ if element_type is not None :
77+ type = element_type
78+ else :
79+ type = args [- 1 ]
80+ return RankedTensorType .get (args [:- 1 ], type )
0 commit comments