@@ -45,8 +45,12 @@ class RAIIMLIRContext:
45
45
context : ir .Context
46
46
location : ir .Location
47
47
48
- def __init__ (self , location : Optional [ir .Location ] = None ):
48
+ def __init__ (
49
+ self , location : Optional [ir .Location ] = None , allow_unregistered_dialects = False
50
+ ):
49
51
self .context = ir .Context ()
52
+ if allow_unregistered_dialects :
53
+ self .context .allow_unregistered_dialects = True
50
54
self .context .__enter__ ()
51
55
if location is None :
52
56
location = ir .Location .unknown ()
@@ -61,6 +65,36 @@ def __del__(self):
61
65
assert ir .Context is not self .context
62
66
63
67
68
+ class RAIIMLIRContextModule :
69
+ context : ir .Context
70
+ location : ir .Location
71
+ insertion_point : ir .InsertionPoint
72
+ module : ir .Module
73
+
74
+ def __init__ (
75
+ self , location : Optional [ir .Location ] = None , allow_unregistered_dialects = False
76
+ ):
77
+ self .context = ir .Context ()
78
+ if allow_unregistered_dialects :
79
+ self .context .allow_unregistered_dialects = True
80
+ self .context .__enter__ ()
81
+ if location is None :
82
+ location = ir .Location .unknown ()
83
+ self .location = location
84
+ self .location .__enter__ ()
85
+ self .module = ir .Module .create ()
86
+ self .insertion_point = ir .InsertionPoint (self .module .body )
87
+ self .insertion_point .__enter__ ()
88
+
89
+ def __del__ (self ):
90
+ self .insertion_point .__exit__ (None , None , None )
91
+ self .location .__exit__ (None , None , None )
92
+ self .context .__exit__ (None , None , None )
93
+ # i guess the extension gets destroyed before this object sometimes?
94
+ if ir is not None :
95
+ assert ir .Context is not self .context
96
+
97
+
64
98
class ExplicitlyManagedModule :
65
99
module : ir .Module
66
100
_ip : ir .InsertionPoint
0 commit comments