11import abc
2- import inspect
2+ import dataclasses
33import math
44import sys
5- import typing
65
6+ import mlir .runtime as rt
77from mlir import ir
88
99import numpy as np
1010
1111
1212class MlirType (abc .ABC ):
13- @classmethod
1413 @abc .abstractmethod
15- def _get_mlir_type (cls ) -> ir .Type : ...
14+ def _get_mlir_type (self ) -> ir .Type : ...
1615
1716
1817def _get_pointer_width () -> int :
@@ -22,106 +21,92 @@ def _get_pointer_width() -> int:
2221_PTR_WIDTH = _get_pointer_width ()
2322
2423
25- def _make_int_classes (namespace : dict [str , object ], bit_widths : typing .Iterable [int ]) -> None :
26- for bw in bit_widths :
27-
28- class SignedBW (SignedIntegerDType ):
29- np_dtype = getattr (np , f"int{ bw } " )
30- bit_width = bw
31-
32- @classmethod
33- def _get_mlir_type (cls ):
34- return ir .IntegerType .get_signless (cls .bit_width )
35-
36- SignedBW .__name__ = f"Int{ bw } "
37- SignedBW .__module__ = __name__
38-
39- class UnsignedBW (UnsignedIntegerDType ):
40- np_dtype = getattr (np , f"uint{ bw } " )
41- bit_width = bw
42-
43- @classmethod
44- def _get_mlir_type (cls ):
45- return ir .IntegerType .get_signless (cls .bit_width )
46-
47- UnsignedBW .__name__ = f"UInt{ bw } "
48- UnsignedBW .__module__ = __name__
49-
50- namespace [SignedBW .__name__ ] = SignedBW
51- namespace [UnsignedBW .__name__ ] = UnsignedBW
52-
53-
24+ @dataclasses .dataclass (eq = True , frozen = True , kw_only = True )
5425class DType (MlirType ):
55- np_dtype : np .dtype
5626 bit_width : int
5727
58- @classmethod
59- def to_ctype ( cls ):
60- return np .ctypeslib . as_ctypes_type ( cls . np_dtype )
61-
28+ @property
29+ @ abc . abstractmethod
30+ def np_dtype ( self ) -> np .dtype :
31+ raise NotImplementedError
6232
63- class FloatingDType (DType ): ...
33+ def to_ctype (self ):
34+ return rt .as_ctype (self .np_dtype )
6435
6536
66- class Float64 (FloatingDType ):
67- np_dtype = np .float64
68- bit_width = 64
37+ @dataclasses .dataclass (eq = True , frozen = True , kw_only = True )
38+ class IeeeRealFloatingDType (DType ):
39+ @property
40+ def np_dtype (self ) -> np .dtype :
41+ return np .dtype (getattr (np , f"float{ self .bit_width } " ))
6942
70- @classmethod
71- def _get_mlir_type (cls ):
72- return ir .F64Type .get ()
43+ def _get_mlir_type (self ) -> ir .Type :
44+ return getattr (ir , f"F{ self .bit_width } Type" ).get ()
7345
7446
75- class Float32 ( FloatingDType ):
76- np_dtype = np . float32
77- bit_width = 32
47+ float64 = IeeeRealFloatingDType ( bit_width = 64 )
48+ float32 = IeeeRealFloatingDType ( bit_width = 32 )
49+ float16 = IeeeRealFloatingDType ( bit_width = 16 )
7850
79- @classmethod
80- def _get_mlir_type (cls ):
81- return ir .F32Type .get ()
8251
52+ @dataclasses .dataclass (eq = True , frozen = True , kw_only = True )
53+ class IeeeComplexFloatingDType (DType ):
54+ @property
55+ def np_dtype (self ) -> np .dtype :
56+ return np .dtype (getattr (np , f"complex{ self .bit_width } " ))
8357
84- class Float16 (FloatingDType ):
85- np_dtype = np .float16
86- bit_width = 16
58+ def _get_mlir_type (self ) -> ir .Type :
59+ return ir .ComplexType .get (getattr (ir , f"F{ self .bit_width // 2 } Type" ).get ())
8760
88- @classmethod
89- def _get_mlir_type (cls ):
90- return ir .F16Type .get ()
9161
62+ complex64 = IeeeComplexFloatingDType (bit_width = 64 )
63+ complex128 = IeeeComplexFloatingDType (bit_width = 128 )
9264
93- class IntegerDType (DType ): ...
9465
66+ @dataclasses .dataclass (eq = True , frozen = True , kw_only = True )
67+ class IntegerDType (DType ):
68+ def _get_mlir_type (self ) -> ir .Type :
69+ return ir .IntegerType .get_signless (self .bit_width )
9570
96- class UnsignedIntegerDType (IntegerDType ): ...
9771
72+ @dataclasses .dataclass (eq = True , frozen = True , kw_only = True )
73+ class UnsignedIntegerDType (IntegerDType ):
74+ @property
75+ def np_dtype (self ) -> np .dtype :
76+ return np .dtype (getattr (np , f"uint{ self .bit_width } " ))
9877
99- class SignedIntegerDType (IntegerDType ): ...
10078
79+ int8 = UnsignedIntegerDType (bit_width = 8 )
80+ int16 = UnsignedIntegerDType (bit_width = 16 )
81+ int32 = UnsignedIntegerDType (bit_width = 32 )
82+ int64 = UnsignedIntegerDType (bit_width = 64 )
10183
102- _make_int_classes (locals (), [8 , 16 , 32 , 64 ])
10384
85+ @dataclasses .dataclass (eq = True , frozen = True , kw_only = True )
86+ class SignedIntegerDType (IntegerDType ):
87+ @property
88+ def np_dtype (self ) -> np .dtype :
89+ return np .dtype (getattr (np , f"int{ self .bit_width } " ))
10490
105- class Index (DType ):
106- np_dtype = np .intp
10791
108- @classmethod
109- def _get_mlir_type (cls ):
110- return ir .IndexType .get ()
92+ uint8 = SignedIntegerDType (bit_width = 8 )
93+ uint16 = SignedIntegerDType (bit_width = 16 )
94+ uint32 = SignedIntegerDType (bit_width = 32 )
95+ uint64 = SignedIntegerDType (bit_width = 64 )
11196
11297
113- IntP : type [ SignedIntegerDType ] = locals ()[f"Int { _PTR_WIDTH } " ]
114- UIntP : type [ UnsignedIntegerDType ] = locals ()[f"UInt { _PTR_WIDTH } " ]
98+ intp : SignedIntegerDType = locals ()[f"int { _PTR_WIDTH } " ]
99+ uintp : UnsignedIntegerDType = locals ()[f"uint { _PTR_WIDTH } " ]
115100
116101
117102def isdtype (dt , / ) -> bool :
118- return isinstance (dt , type ) and issubclass ( dt , DType ) and not inspect . isabstract ( dt )
103+ return isinstance (dt , DType )
119104
120105
121106NUMPY_DTYPE_MAP = {np .dtype (dt .np_dtype ): dt for dt in locals ().values () if isdtype (dt )}
122107
123108
124- def asdtype (dt , / ) -> type [ DType ] :
109+ def asdtype (dt , / ) -> DType :
125110 if isdtype (dt ):
126111 return dt
127112
0 commit comments