Skip to content

Commit 2112e20

Browse files
committed
rename types (no _t)
1 parent 9c4f3cb commit 2112e20

File tree

11 files changed

+160
-161
lines changed

11 files changed

+160
-161
lines changed

mlir_utils/dialects/ext/scf.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def yield_(*args):
126126

127127
assert len(results) == len(unpacked_args), f"{results=}, {unpacked_args=}"
128128
for i, r in enumerate(results):
129-
if r.type == T._placeholder_opaque_t():
129+
if r.type == T.placeholder_opaque():
130130
r.set_type(unpacked_args[i].type)
131131

132132
results = maybe_cast(results)
@@ -295,7 +295,7 @@ def visit_If(self, updated_node: ast.If) -> ast.With | list[ast.With, ast.With]:
295295

296296
test = updated_node.test
297297
results = [
298-
ast_call(T._placeholder_opaque_t.__name__)
298+
ast_call(T.placeholder_opaque.__name__)
299299
for _ in range(len(last_statement.value.args))
300300
]
301301
results = ast.fix_missing_locations(
@@ -358,7 +358,7 @@ def patch_bytecode(self, code: ConcreteBytecode, f):
358358
f.__globals__[yield_.__name__] = yield_
359359
f.__globals__[if_ctx_manager.__name__] = if_ctx_manager
360360
f.__globals__[else_ctx_manager.__name__] = else_ctx_manager
361-
f.__globals__[T._placeholder_opaque_t.__name__] = T._placeholder_opaque_t
361+
f.__globals__[T.placeholder_opaque.__name__] = T.placeholder_opaque
362362
return code
363363

364364

mlir_utils/dialects/ext/tensor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
ShapedType,
1515
)
1616

17+
import mlir_utils.types as T
1718
from mlir_utils.dialects import tensor
1819
from mlir_utils.dialects.ext.arith import ArithValue, Scalar, constant
19-
from mlir_utils.types import tensor_t
2020
from mlir_utils.util import (
2121
get_result_or_results,
2222
maybe_cast,
@@ -52,7 +52,7 @@ def extract_slice(
5252
assert offsets or static_offsets and bool(offsets) != bool(static_offsets)
5353
assert strides or static_strides and bool(strides) != bool(static_strides)
5454
sizes = []
55-
result = tensor_t(*static_sizes, source.dtype)
55+
result = T.tensor(*static_sizes, source.dtype)
5656
return tensor.extract_slice(
5757
result,
5858
source,

mlir_utils/types.py

Lines changed: 80 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import sys
21
from functools import partial
32
from typing import Union
43

@@ -25,72 +24,72 @@
2524
VectorType,
2625
)
2726

28-
_index_t = lambda: IndexType.get()
29-
_bool_t = lambda: IntegerType.get_signless(1)
27+
_index = lambda: IndexType.get()
28+
_bool = lambda: IntegerType.get_signless(1)
3029

31-
_i8_t = lambda: IntegerType.get_signless(8)
32-
_i16_t = lambda: IntegerType.get_signless(16)
33-
_i32_t = lambda: IntegerType.get_signless(32)
34-
_i64_t = lambda: IntegerType.get_signless(64)
30+
_i8 = lambda: IntegerType.get_signless(8)
31+
_i16 = lambda: IntegerType.get_signless(16)
32+
_i32 = lambda: IntegerType.get_signless(32)
33+
_i64 = lambda: IntegerType.get_signless(64)
3534

36-
_si8_t = lambda: IntegerType.get_signed(8)
37-
_si16_t = lambda: IntegerType.get_signed(16)
38-
_si32_t = lambda: IntegerType.get_signed(32)
39-
_si64_t = lambda: IntegerType.get_signed(64)
35+
_si8 = lambda: IntegerType.get_signed(8)
36+
_si16 = lambda: IntegerType.get_signed(16)
37+
_si32 = lambda: IntegerType.get_signed(32)
38+
_si64 = lambda: IntegerType.get_signed(64)
4039

41-
_ui8_t = lambda: IntegerType.get_unsigned(8)
42-
_ui16_t = lambda: IntegerType.get_unsigned(16)
43-
_ui32_t = lambda: IntegerType.get_unsigned(32)
44-
_ui64_t = lambda: IntegerType.get_unsigned(64)
40+
_ui8 = lambda: IntegerType.get_unsigned(8)
41+
_ui16 = lambda: IntegerType.get_unsigned(16)
42+
_ui32 = lambda: IntegerType.get_unsigned(32)
43+
_ui64 = lambda: IntegerType.get_unsigned(64)
4544

46-
_f16_t = lambda: F16Type.get()
47-
_f32_t = lambda: F32Type.get()
48-
_f64_t = lambda: F64Type.get()
49-
_bf16_t = lambda: BF16Type.get()
45+
_f16 = lambda: F16Type.get()
46+
_f32 = lambda: F32Type.get()
47+
_f64 = lambda: F64Type.get()
48+
_bf16 = lambda: BF16Type.get()
5049

51-
_f8e5m2_t = lambda: Float8E5M2Type.get()
52-
_f8e4m3_t = lambda: Float8E4M3FNType.get()
53-
_f8e4m3b11fnuz_t = lambda: Float8E4M3B11FNUZType.get()
50+
_f8e5m2 = lambda: Float8E5M2Type.get()
51+
_f8e4m3 = lambda: Float8E4M3FNType.get()
52+
_f8e4m3b11fnuz = lambda: Float8E4M3B11FNUZType.get()
5453

55-
_cmp16_t = lambda: ComplexType.get(_f16_t())
56-
_cmp32_t = lambda: ComplexType.get(_f32_t())
57-
_cmp64_t = lambda: ComplexType.get(_f64_t())
54+
_cmp16 = lambda: ComplexType.get(_f16())
55+
_cmp32 = lambda: ComplexType.get(_f32())
56+
_cmp64 = lambda: ComplexType.get(_f64())
5857

59-
_none_t = lambda: NoneType.get()
58+
_none = lambda: NoneType.get()
6059

61-
opaque_t = lambda dialect_namespace, buffer: OpaqueType.get(dialect_namespace, buffer)
60+
opaque = lambda dialect_namespace, buffer: OpaqueType.get(dialect_namespace, buffer)
6261

6362

64-
def _placeholder_opaque_t():
65-
return opaque_t("scf", "placeholder")
63+
def placeholder_opaque():
64+
return opaque("scf", "placeholder")
6665

6766

6867
_name_to_type = {
69-
"index_t": _index_t,
70-
"bool_t": _bool_t,
71-
"i8_t": _i8_t,
72-
"i16_t": _i16_t,
73-
"i32_t": _i32_t,
74-
"i64_t": _i64_t,
75-
"si8_t": _si8_t,
76-
"si16_t": _si16_t,
77-
"si32_t": _si32_t,
78-
"si64_t": _si64_t,
79-
"ui8_t": _ui8_t,
80-
"ui16_t": _ui16_t,
81-
"ui32_t": _ui32_t,
82-
"ui64_t": _ui64_t,
83-
"f16_t": _f16_t,
84-
"f32_t": _f32_t,
85-
"f64_t": _f64_t,
86-
"bf16_t": _bf16_t,
87-
"f8e5m2_t": _f8e5m2_t,
88-
"f8e4m3_t": _f8e4m3_t,
89-
"f8e4m3b11fnuz_t": _f8e4m3b11fnuz_t,
90-
"cmp16_t": _cmp16_t,
91-
"cmp32_t": _cmp32_t,
92-
"cmp64_t": _cmp64_t,
93-
"none_t": _none_t,
68+
"index": _index,
69+
"bool": _bool,
70+
"i8": _i8,
71+
"i16": _i16,
72+
"i32": _i32,
73+
"i64": _i64,
74+
"si8": _si8,
75+
"si16": _si16,
76+
"si32": _si32,
77+
"si64": _si64,
78+
"ui8": _ui8,
79+
"ui16": _ui16,
80+
"ui32": _ui32,
81+
"ui64": _ui64,
82+
"f16": _f16,
83+
"f32": _f32,
84+
"f64": _f64,
85+
"bf16": _bf16,
86+
"f8e5m2": _f8e5m2,
87+
"f8e4m3": _f8e4m3,
88+
"f8e4m3b11fnuz": _f8e4m3b11fnuz,
89+
"cmp16": _cmp16,
90+
"cmp32": _cmp32,
91+
"cmp64": _cmp64,
92+
"none": _none,
9493
}
9594

9695

@@ -102,19 +101,19 @@ def __getattr__(name):
102101

103102

104103
_np_dtype_to_mlir_type_ctor = {
105-
np.int8: _i8_t,
106-
np.int16: _i16_t,
107-
np.int32: _i32_t,
104+
np.int8: _i8,
105+
np.int16: _i16,
106+
np.int32: _i32,
108107
# windows
109-
np.intc: _i32_t,
110-
np.int64: _i64_t,
108+
np.intc: _i32,
109+
np.int64: _i64,
111110
# is technically wrong i guess but numpy by default casts python scalars to this
112111
# so to support passing lists of ints we map to index type
113-
np.longlong: _index_t,
114-
np.uintp: _index_t,
115-
np.float16: _f16_t,
116-
np.float32: _f32_t,
117-
np.float64: _f64_t,
112+
np.longlong: _index,
113+
np.uintp: _index,
114+
np.float16: _f16,
115+
np.float32: _f32,
116+
np.float64: _f64,
118117
}
119118

120119
_mlir_type_ctor_to_np_dtype = lambda: {
@@ -146,16 +145,16 @@ def infer_mlir_type(
146145
MLIR type corresponding to py_val.
147146
"""
148147
if isinstance(py_val, bool):
149-
return _bool_t()
148+
return _bool()
150149
elif isinstance(py_val, int):
151150
if -(2 ** 31) <= py_val < 2 ** 31:
152-
return _i32_t()
151+
return _i32()
153152
elif 2 ** 31 <= py_val < 2 ** 32:
154-
return _ui32_t()
153+
return _ui32()
155154
elif -(2 ** 63) <= py_val < 2 ** 63:
156-
return _i64_t()
155+
return _i64()
157156
elif 2 ** 63 <= py_val < 2 ** 64:
158-
return _ui64_t()
157+
return _ui64()
159158
else:
160159
raise RuntimeError(f"Nonrepresentable integer {py_val}.")
161160
elif isinstance(py_val, float):
@@ -165,9 +164,9 @@ def infer_mlir_type(
165164
or py_val != py_val # NaN
166165
or np.finfo(np.float32).min <= abs(py_val) <= np.finfo(np.float32).max
167166
):
168-
return _f32_t()
167+
return _f32()
169168
else:
170-
return _f64_t()
169+
return _f64()
171170
elif isinstance(py_val, np.ndarray):
172171
dtype = np_dtype_to_mlir_type(py_val.dtype.type)
173172
return RankedTensorType.get(py_val.shape, dtype)
@@ -177,9 +176,9 @@ def infer_mlir_type(
177176
)
178177

179178

180-
def shaped_t(*args, element_type: Type = None, type_constructor=None):
179+
def shaped(*args, element_type: Type = None, type_constructor=None):
181180
if type_constructor is None:
182-
raise ValueError("shaped_t is an abstract base class - cannot be constructed")
181+
raise ValueError("shaped is an abstract base class - cannot be constructed")
183182
if (element_type is None and args and not isinstance(args[-1], Type)) or (
184183
args and isinstance(args[-1], Type) and element_type is not None
185184
):
@@ -198,33 +197,33 @@ def shaped_t(*args, element_type: Type = None, type_constructor=None):
198197
return type_constructor(type)
199198

200199

201-
def vector_t(*args, element_type: Type = None):
202-
return shaped_t(*args, element_type=element_type, type_constructor=VectorType.get)
200+
def vector(*args, element_type: Type = None):
201+
return shaped(*args, element_type=element_type, type_constructor=VectorType.get)
203202

204203

205-
def tensor_t(*args, element_type: Type = None):
204+
def tensor(*args, element_type: Type = None):
206205
if not len(args) or len(args) == 1 and isinstance(args[-1], Type):
207-
return shaped_t(
206+
return shaped(
208207
*args, element_type=element_type, type_constructor=UnrankedTensorType.get
209208
)
210209
else:
211-
return shaped_t(
210+
return shaped(
212211
*args, element_type=element_type, type_constructor=RankedTensorType.get
213212
)
214213

215214

216-
def memref_t(*args, element_type: Type = None, memory_space: int = None):
215+
def memref(*args, element_type: Type = None, memory_space: int = None):
217216
if memory_space is None:
218217
memory_space = 0
219218
memory_space = Attribute.parse(str(memory_space))
220219
if not len(args) or len(args) == 1 and isinstance(args[-1], Type):
221-
return shaped_t(
220+
return shaped(
222221
*args,
223222
element_type=element_type,
224223
type_constructor=partial(UnrankedMemRefType.get, memory_space=memory_space),
225224
)
226225
else:
227-
return shaped_t(
226+
return shaped(
228227
*args,
229228
element_type=element_type,
230229
type_constructor=partial(MemRefType.get, memory_space=memory_space),

tests/test_location_tracking.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,13 @@ def get_asm(operation):
2929
def test_if_replace_yield_5(ctx: MLIRContext):
3030
@canonicalize(using=canonicalizer)
3131
def iffoo():
32-
one = constant(1.0, T.f32_t)
33-
two = constant(2.0, T.f32_t)
32+
one = constant(1.0, T.f32)
33+
two = constant(2.0, T.f32)
3434
if one < two:
35-
three = constant(3.0, T.f32_t)
35+
three = constant(3.0, T.f32)
3636
res1, res2, res3 = yield three, three, three
3737
else:
38-
four = constant(4.0, T.f32_t)
38+
four = constant(4.0, T.f32)
3939
res1, res2, res3 = yield four, four, four
4040
return
4141

@@ -71,11 +71,11 @@ def iffoo():
7171

7272

7373
def test_block_args(ctx: MLIRContext):
74-
one = constant(1, T.index_t)
75-
two = constant(2, T.index_t)
74+
one = constant(1, T.index)
75+
two = constant(2, T.index)
7676

77-
@generate(T.tensor_t(S, 3, S, T.f32_t), dynamic_extents=[one, two])
78-
def demo_fun1(i: T.index_t, j: T.index_t, k: T.index_t):
77+
@generate(T.tensor(S, 3, S, T.f32), dynamic_extents=[one, two])
78+
def demo_fun1(i: T.index, j: T.index, k: T.index):
7979
one = constant(1.0)
8080
tensor_yield(one)
8181

tests/test_operator_overloading.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,9 @@ def test_tensor_arithmetic(ctx: MLIRContext):
6666
three = one + two
6767
assert isinstance(three, Scalar)
6868

69-
ten1 = empty((10, 10, 10), T.f32_t)
69+
ten1 = empty((10, 10, 10), T.f32)
7070
assert isinstance(ten1, Tensor)
71-
ten2 = empty((10, 10, 10), T.f32_t)
71+
ten2 = empty((10, 10, 10), T.f32)
7272
assert isinstance(ten2, Tensor)
7373
ten3 = ten1 + ten2
7474
assert isinstance(ten3, Tensor)

tests/test_regions.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
# noinspection PyUnresolvedReferences
1515
from mlir_utils.testing import mlir_ctx as ctx, filecheck, MLIRContext
16-
from mlir_utils.types import tensor_t
16+
from mlir_utils.types import tensor
1717

1818
# needed since the fix isn't defined here nor conftest.py
1919
pytest.mark.usefixtures("ctx")
@@ -95,11 +95,11 @@ def demo_fun1():
9595

9696

9797
def test_block_args(ctx: MLIRContext):
98-
one = constant(1, T.index_t)
99-
two = constant(2, T.index_t)
98+
one = constant(1, T.index)
99+
two = constant(2, T.index)
100100

101-
@generate(tensor_t(S, 3, S, T.f32_t), dynamic_extents=[one, two])
102-
def demo_fun1(i: T.index_t, j: T.index_t, k: T.index_t):
101+
@generate(tensor(S, 3, S, T.f32), dynamic_extents=[one, two])
102+
def demo_fun1(i: T.index, j: T.index, k: T.index):
103103
one = constant(1.0)
104104
tensor_yield(one)
105105

tests/test_scf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -286,11 +286,11 @@ def test_if_ctx_manager(ctx: MLIRContext):
286286
# @formatter:off
287287
one = constant(1.0)
288288
two = constant(2.0)
289-
with if_ctx_manager(one < two, results=[T._placeholder_opaque_t()]) as if_op: # if
289+
with if_ctx_manager(one < two, results=[T.placeholder_opaque()]) as if_op: # if
290290
three = constant(3.0)
291291
res = yield_(three)
292292
with else_ctx_manager(if_op) as _: # else
293-
with if_ctx_manager(one < two, results=[T._placeholder_opaque_t()]) as if_op: # if
293+
with if_ctx_manager(one < two, results=[T.placeholder_opaque()]) as if_op: # if
294294
three = constant(4.0)
295295
res = yield_(three)
296296
with else_ctx_manager(if_op) as _: # else

0 commit comments

Comments
 (0)