Skip to content

Commit 8aba4e7

Browse files
authored
add RAIIMLIRContextModule (#111)
1 parent e7592ab commit 8aba4e7

File tree

2 files changed

+62
-2
lines changed

2 files changed

+62
-2
lines changed

mlir/extras/context.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,12 @@ class RAIIMLIRContext:
4545
context: ir.Context
4646
location: ir.Location
4747

48-
def __init__(self, location: Optional[ir.Location] = None):
48+
def __init__(
49+
self, location: Optional[ir.Location] = None, allow_unregistered_dialects=False
50+
):
4951
self.context = ir.Context()
52+
if allow_unregistered_dialects:
53+
self.context.allow_unregistered_dialects = True
5054
self.context.__enter__()
5155
if location is None:
5256
location = ir.Location.unknown()
@@ -61,6 +65,36 @@ def __del__(self):
6165
assert ir.Context is not self.context
6266

6367

68+
class RAIIMLIRContextModule:
69+
context: ir.Context
70+
location: ir.Location
71+
insertion_point: ir.InsertionPoint
72+
module: ir.Module
73+
74+
def __init__(
75+
self, location: Optional[ir.Location] = None, allow_unregistered_dialects=False
76+
):
77+
self.context = ir.Context()
78+
if allow_unregistered_dialects:
79+
self.context.allow_unregistered_dialects = True
80+
self.context.__enter__()
81+
if location is None:
82+
location = ir.Location.unknown()
83+
self.location = location
84+
self.location.__enter__()
85+
self.module = ir.Module.create()
86+
self.insertion_point = ir.InsertionPoint(self.module.body)
87+
self.insertion_point.__enter__()
88+
89+
def __del__(self):
90+
self.insertion_point.__exit__(None, None, None)
91+
self.location.__exit__(None, None, None)
92+
self.context.__exit__(None, None, None)
93+
# i guess the extension gets destroyed before this object sometimes?
94+
if ir is not None:
95+
assert ir.Context is not self.context
96+
97+
6498
class ExplicitlyManagedModule:
6599
module: ir.Module
66100
_ip: ir.InsertionPoint

tests/test_func.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import inspect
22
import sys
3+
import threading
34
from textwrap import dedent
45
from typing import TypeVar
56

@@ -8,7 +9,7 @@
89
import mlir.extras.types as T
910

1011
from mlir.extras.ast.canonicalize import canonicalize
11-
from mlir.extras.context import mlir_mod_ctx
12+
from mlir.extras.context import mlir_mod_ctx, RAIIMLIRContextModule
1213
from mlir.extras.dialects.ext.arith import constant
1314
from mlir.extras.dialects.ext.func import func
1415
from mlir.extras.dialects.ext import linalg, arith, scf, memref
@@ -294,3 +295,28 @@ def mat_product_kernel(
294295
"""
295296
)
296297
filecheck(correct, ctx.module)
298+
299+
300+
def test_raii_mlir_context_module():
301+
tls = threading.local()
302+
tls.ctx = RAIIMLIRContextModule()
303+
304+
@func
305+
def demo_fun1():
306+
one = constant(1)
307+
return one
308+
309+
assert hasattr(demo_fun1, "emit")
310+
assert inspect.ismethod(demo_fun1.emit)
311+
demo_fun1.emit()
312+
correct = dedent(
313+
"""\
314+
module {
315+
func.func @demo_fun1() -> i32 {
316+
%c1_i32 = arith.constant 1 : i32
317+
return %c1_i32 : i32
318+
}
319+
}
320+
"""
321+
)
322+
filecheck(correct, tls.ctx.module)

0 commit comments

Comments
 (0)