Commit 7af45b1

remove default context
1 parent 1cd37d6 commit 7af45b1

12 files changed: +125 −109 lines changed


mlir_utils/__init__.py

Lines changed: 1 addition & 20 deletions
@@ -1,22 +1,3 @@
-import atexit
-
 from ._configuration.configuration import alias_upstream_bindings
 
-if alias_upstream_bindings():
-    from mlir import ir
-
-    DefaultContext = ir.Context()
-    # Push a default context onto the context stack at import time.
-    DefaultContext.__enter__()
-    DefaultContext.allow_unregistered_dialects = False
-
-    DefaultLocation = ir.Location.unknown()
-    DefaultLocation.__enter__()
-
-    @atexit.register
-    def __exit_ctxt():
-        DefaultContext.__exit__(None, None, None)
-
-    @atexit.register
-    def __exit_loc():
-        DefaultLocation.__exit__(None, None, None)
+alias_upstream_bindings()
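
Note: with this change the package no longer creates and enters a global ir.Context / ir.Location at import time. A minimal sketch of what callers now do instead, assuming the upstream mlir bindings are importable after alias_upstream_bindings():

from mlir import ir

with ir.Context() as ctx, ir.Location.unknown():
    ctx.allow_unregistered_dialects = False
    module = ir.Module.create()  # built under an explicitly entered context/location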

mlir_utils/context.py

Lines changed: 1 addition & 3 deletions
@@ -4,8 +4,6 @@
 
 import mlir.ir
 
-from mlir_utils import DefaultContext
-
 
 @dataclass
 class MLIRContext:
@@ -19,7 +17,7 @@ def __str__(self):
 @contextmanager
 def mlir_mod_ctx(
     src: Optional[str] = None,
-    context: mlir.ir.Context = DefaultContext,
+    context: mlir.ir.Context = None,
     location: mlir.ir.Location = None,
     allow_unregistered_dialects=False,
 ) -> MLIRContext:
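
Note: with DefaultContext gone, mlir_mod_ctx defaults its context argument to None. A hypothetical sketch of the None-handling this default implies (the real body lives in mlir_utils/context.py; the helper name below is made up for illustration):

from contextlib import contextmanager

import mlir.ir


@contextmanager
def _mod_ctx_sketch(context=None, allow_unregistered_dialects=False):
    if context is None:
        # construct a fresh context instead of sharing a module-level default
        context = mlir.ir.Context()
    context.allow_unregistered_dialects = allow_unregistered_dialects
    with context, mlir.ir.Location.unknown(context):
        yield mlir.ir.Module.create()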

mlir_utils/dialects/ext/arith.py

Lines changed: 2 additions & 2 deletions
@@ -35,7 +35,7 @@
 except ModuleNotFoundError:
     pass
 
-from mlir_utils.types import infer_mlir_type, MLIR_TYPE_TO_NP_DTYPE
+from mlir_utils.types import infer_mlir_type, mlir_type_to_np_dtype
 
 
 def constant(
@@ -69,7 +69,7 @@ def constant(
         ranked_tensor_type = RankedTensorType(type)
         value = np.ones(
            ranked_tensor_type.shape,
-            dtype=MLIR_TYPE_TO_NP_DTYPE()[ranked_tensor_type.element_type],
+            dtype=mlir_type_to_np_dtype(ranked_tensor_type.element_type),
        )
    assert type is not None
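
Note: the dict-valued MLIR_TYPE_TO_NP_DTYPE() lambda is replaced by the mlir_type_to_np_dtype function added to mlir_utils.types in this commit, which maps an MLIR element type to a NumPy dtype. A small usage sketch, assuming an active Context so the element type can be constructed:

from mlir import ir
from mlir_utils.types import mlir_type_to_np_dtype

with ir.Context():
    dtype = mlir_type_to_np_dtype(ir.F32Type.get())  # numpy.float32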

mlir_utils/dialects/ext/scf.py

Lines changed: 6 additions & 9 deletions
@@ -9,6 +9,7 @@
 from mlir.dialects.scf import IfOp, ForOp
 from mlir.ir import InsertionPoint, Value, OpResultList, OpResult
 
+import mlir_utils.types as T
 from mlir_utils.ast.canonicalize import (
     StrictTransformer,
     Canonicalizer,
@@ -18,7 +19,6 @@
 from mlir_utils.ast.util import ast_call
 from mlir_utils.dialects.ext.arith import constant
 from mlir_utils.dialects.scf import yield_ as yield__
-from mlir_utils.types import opaque_t
 from mlir_utils.util import (
     region_op,
     maybe_cast,
@@ -101,8 +101,6 @@ def _if(cond, results_=None, *, has_else=False, loc=None, ip=None):
 
 if_ = region_op(_if, terminator=yield__)
 
-_placeholder_opaque_t = opaque_t("scf", "placeholder")
-
 
 class IfStack:
     __current_if_op: list[IfOp] = []
@@ -175,7 +173,7 @@ def yield_(*args):
 
     assert len(results) == len(unpacked_args), f"{results=}, {unpacked_args=}"
     for i, r in enumerate(results):
-        if r.type == _placeholder_opaque_t:
+        if r.type == T._placeholder_opaque_t():
            r.set_type(unpacked_args[i].type)
 
    yield_(*args)
@@ -325,13 +323,13 @@ def insert_with_results(
        ), f"conditional with := must explicitly yield on last line"
        yield_expr = last_statement.body[0]
        if m.matches(yield_expr.value, m.Call(func=m.Name(stack_yield.__name__))):
-            results = [cst.Element(cst.Name("_placeholder_opaque_t"))] * len(
+            results = [cst.Element(ast_call(T._placeholder_opaque_t.__name__))] * len(
                yield_expr.value.args
            )
        elif m.matches(yield_expr.value.value, m.Name()):
-            results = [cst.Element(cst.Name("_placeholder_opaque_t"))]
+            results = [cst.Element(ast_call(T._placeholder_opaque_t.__name__))]
        elif m.matches(yield_expr.value.value, m.Tuple()):
-            results = [cst.Element(cst.Name("_placeholder_opaque_t"))] * len(
+            results = [cst.Element(ast_call(T._placeholder_opaque_t.__name__))] * len(
                yield_expr.value.value.elements
            )
        results = cst.Tuple(results)
@@ -422,14 +420,13 @@ def patch_bytecode(self, code: ConcreteBytecode, f):
                    str(OpCode.NOP), lineno=c.lineno, location=c.location
                )
 
-        # TODO(max): this is bad
        f.__globals__[else_.__name__] = else_
        f.__globals__[end_branch.__name__] = end_branch
        f.__globals__[end_if.__name__] = end_if
        f.__globals__[stack_if.__name__] = stack_if
        f.__globals__[stack_yield.__name__] = stack_yield
        f.__globals__[yield_.__name__] = yield_
-        f.__globals__["_placeholder_opaque_t"] = _placeholder_opaque_t
+        f.__globals__[T._placeholder_opaque_t.__name__] = T._placeholder_opaque_t
        return code
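
Note: the placeholder result type used while rewriting if/yield regions also moves behind a function, since OpaqueType.get can only run once a context is active and there no longer is one at import time. A minimal sketch of the lazy pattern, mirroring T._placeholder_opaque_t:

from mlir import ir


def _placeholder_opaque_t():
    return ir.OpaqueType.get("scf", "placeholder")


with ir.Context():
    p = _placeholder_opaque_t()           # built under the active context
    assert p == _placeholder_opaque_t()   # uniqued, so repeated calls compare equal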

mlir_utils/types.py

Lines changed: 65 additions & 27 deletions
@@ -1,3 +1,4 @@
+import sys
 from functools import partial
 from typing import Union
 
@@ -19,33 +20,70 @@
     OpaqueType,
 )
 
-index_t = IndexType.get()
-bool_t = IntegerType.get_signless(1)
-i8_t = IntegerType.get_signless(8)
-i16_t = IntegerType.get_signless(16)
-i32_t = IntegerType.get_signless(32)
-i64_t = IntegerType.get_signless(64)
-f16_t = F16Type.get()
-f32_t = F32Type.get()
-f64_t = F64Type.get()
-bf16_t = BF16Type.get()
+_index_t = lambda: IndexType.get()
+_bool_t = lambda: IntegerType.get_signless(1)
+_i8_t = lambda: IntegerType.get_signless(8)
+_i16_t = lambda: IntegerType.get_signless(16)
+_i32_t = lambda: IntegerType.get_signless(32)
+_i64_t = lambda: IntegerType.get_signless(64)
+_f16_t = lambda: F16Type.get()
+_f32_t = lambda: F32Type.get()
+_f64_t = lambda: F64Type.get()
+_bf16_t = lambda: BF16Type.get()
 opaque_t = lambda dialect_namespace, buffer: OpaqueType.get(dialect_namespace, buffer)
 
-NP_DTYPE_TO_MLIR_TYPE = lambda: {
-    np.int8: i8_t,
-    np.int16: i16_t,
-    np.int32: i32_t,
-    np.int64: i64_t,
-    # this is techincally wrong i guess but numpy by default casts python scalars to this
-    # so to support passing lists of ints we map this to index type
-    np.longlong: index_t,
-    np.uintp: index_t,
-    np.float16: f16_t,
-    np.float32: f32_t,
-    np.float64: f64_t,
+
+def _placeholder_opaque_t():
+    return opaque_t("scf", "placeholder")
+
+
+_name_to_type = {
+    "index_t": _index_t,
+    "bool_t": _bool_t,
+    "i8_t": _i8_t,
+    "i16_t": _i16_t,
+    "i32_t": _i32_t,
+    "i64_t": _i64_t,
+    "f16_t": _f16_t,
+    "f32_t": _f32_t,
+    "f64_t": _f64_t,
+    "bf16_t": _bf16_t,
+}
+
+
+def __getattr__(name):
+    if name in _name_to_type:
+        return _name_to_type[name]()
+    # this kicks it to the default module attribute lookup (i.e., functions defined below and such)
+    return None
+
+
+_np_dtype_to_mlir_type_ctor = {
+    np.int8: _i8_t,
+    np.int16: _i16_t,
+    np.int32: _i32_t,
+    np.int64: _i64_t,
+    # is technically wrong i guess but numpy by default casts python scalars to this
+    # so to support passing lists of ints we map to index type
+    np.longlong: _index_t,
+    np.uintp: _index_t,
+    np.float16: _f16_t,
+    np.float32: _f32_t,
+    np.float64: _f64_t,
 }
 
-MLIR_TYPE_TO_NP_DTYPE = lambda: {v: k for k, v in NP_DTYPE_TO_MLIR_TYPE().items()}
+_mlir_type_ctor_to_np_dtype = lambda: {
+    v: k for k, v in _np_dtype_to_mlir_type_ctor.items()
+}
+
+
+def np_dtype_to_mlir_type(np_dtype):
+    return _np_dtype_to_mlir_type_ctor[np_dtype]()
+
+
+def mlir_type_to_np_dtype(mlir_type):
+    _mlir_type_to_np_dtype = {v(): k for k, v in _np_dtype_to_mlir_type_ctor.items()}
+    return _mlir_type_to_np_dtype[mlir_type]
 
 
 def infer_mlir_type(
@@ -62,13 +100,13 @@ def infer_mlir_type(
         MLIR type corresponding to py_val.
     """
     if isinstance(py_val, bool):
-        return bool_t
+        return _bool_t()
     elif isinstance(py_val, int):
-        return i64_t
+        return _i64_t()
    elif isinstance(py_val, float):
-        return f64_t
+        return _f64_t()
    elif isinstance(py_val, np.ndarray):
-        dtype = NP_DTYPE_TO_MLIR_TYPE()[py_val.dtype.type]
+        dtype = np_dtype_to_mlir_type(py_val.dtype.type)
        return RankedTensorType.get(py_val.shape, dtype)
    else:
        raise NotImplementedError(
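
Note: the type constants become zero-argument constructors plus a module-level __getattr__ (PEP 562), so T.i32_t and friends are built lazily against whatever context is active at access time instead of at import time (__getattr__ only fires after normal module attribute lookup misses, so the private _i32_t-style names and the functions above are still found directly). A hedged usage sketch, assuming the upstream mlir bindings:

from mlir import ir

import mlir_utils.types as T

with ir.Context():
    i32 = T.i32_t    # IntegerType.get_signless(32), built under this context
    idx = T.index_t  # IndexType.get()
    print(i32, idx)  # i32 index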

tests/test_location_tracking.py

Lines changed: 8 additions & 8 deletions
@@ -1,8 +1,10 @@
+from os import sep
 from pathlib import Path
 from textwrap import dedent
-from os import sep
+
 import pytest
 
+import mlir_utils.types as T
 from mlir_utils.ast.canonicalize import canonicalize
 from mlir_utils.dialects.ext.arith import constant
 from mlir_utils.dialects.ext.scf import (
@@ -14,12 +16,10 @@
 
 # noinspection PyUnresolvedReferences
 from mlir_utils.testing import mlir_ctx as ctx, filecheck, MLIRContext
-from mlir_utils.types import f64_t, index_t, tensor_t
 
 # needed since the fix isn't defined here nor conftest.py
 pytest.mark.usefixtures("ctx")
 
-
 THIS_DIR = str(Path(__file__).parent.absolute())
 
 
@@ -34,7 +34,7 @@ def test_if_replace_yield_5(ctx: MLIRContext):
     def iffoo():
         one = constant(1.0)
         two = constant(2.0)
-        if res := stack_if(one < two, (f64_t, f64_t, f64_t)):
+        if res := stack_if(one < two, (T.f64_t, T.f64_t, T.f64_t)):
            three = constant(3.0)
            yield three, three, three
        else:
@@ -74,11 +74,11 @@ def iffoo():
 
 
 def test_block_args(ctx: MLIRContext):
-    one = constant(1, index_t)
-    two = constant(2, index_t)
+    one = constant(1, T.index_t)
+    two = constant(2, T.index_t)
 
-    @generate(tensor_t(S, 3, S, f64_t), dynamic_extents=[one, two])
-    def demo_fun1(i: index_t, j: index_t, k: index_t):
+    @generate(T.tensor_t(S, 3, S, T.f64_t), dynamic_extents=[one, two])
+    def demo_fun1(i: T.index_t, j: T.index_t, k: T.index_t):
        one = constant(1.0)
        tensor_yield(one)
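
Note: the tests swap eager imports like from mlir_utils.types import f64_t, index_t for import mlir_utils.types as T with attribute access inside the test body; an eager from-import would now trigger the module __getattr__ (and type construction) at collection time, before the ctx fixture has entered a context. A hedged sketch of the resulting pattern (the ctx fixture comes from mlir_utils.testing, as in the diff; the test name is made up):

import mlir_utils.types as T
from mlir_utils.dialects.ext.arith import constant


def test_uses_lazy_types(ctx):
    # T.index_t is constructed here, inside the context entered by the fixture
    one = constant(1, T.index_t)
    two = constant(2, T.index_t)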

tests/test_operator_overloading.py

Lines changed: 3 additions & 3 deletions
@@ -2,12 +2,12 @@
 
 import pytest
 
+import mlir_utils.types as T
 from mlir_utils.dialects.ext.arith import constant, Scalar
 from mlir_utils.dialects.ext.tensor import Tensor, empty
 
 # noinspection PyUnresolvedReferences
 from mlir_utils.testing import mlir_ctx as ctx, filecheck, MLIRContext
-from mlir_utils.types import f64_t
 
 # needed since the fix isn't defined here nor conftest.py
 pytest.mark.usefixtures("ctx")
@@ -21,9 +21,9 @@ def test_tensor_arithmetic(ctx: MLIRContext):
     three = one + two
     assert isinstance(three, Scalar)
 
-    ten1 = empty((10, 10, 10), f64_t)
+    ten1 = empty((10, 10, 10), T.f64_t)
     assert isinstance(ten1, Tensor)
-    ten2 = empty((10, 10, 10), f64_t)
+    ten2 = empty((10, 10, 10), T.f64_t)
     assert isinstance(ten2, Tensor)
     ten3 = ten1 + ten2
     assert isinstance(ten3, Tensor)

tests/test_regions.py

Lines changed: 6 additions & 5 deletions
@@ -2,6 +2,7 @@
 
 import pytest
 
+import mlir_utils.types as T
 from mlir_utils.dialects.ext.arith import constant
 from mlir_utils.dialects.ext.func import func
 from mlir_utils.dialects.ext.tensor import S, rank
@@ -11,7 +12,7 @@
 
 # noinspection PyUnresolvedReferences
 from mlir_utils.testing import mlir_ctx as ctx, filecheck, MLIRContext
-from mlir_utils.types import f64_t, index_t, tensor_t
+from mlir_utils.types import tensor_t
 
 # needed since the fix isn't defined here nor conftest.py
 pytest.mark.usefixtures("ctx")
@@ -93,11 +94,11 @@ def demo_fun1():
 
 
 def test_block_args(ctx: MLIRContext):
-    one = constant(1, index_t)
-    two = constant(2, index_t)
+    one = constant(1, T.index_t)
+    two = constant(2, T.index_t)
 
-    @generate(tensor_t(S, 3, S, f64_t), dynamic_extents=[one, two])
-    def demo_fun1(i: index_t, j: index_t, k: index_t):
+    @generate(tensor_t(S, 3, S, T.f64_t), dynamic_extents=[one, two])
+    def demo_fun1(i: T.index_t, j: T.index_t, k: T.index_t):
        one = constant(1.0)
        tensor_yield(one)
