Skip to content

Commit 22d36f8

Browse files
committed
Address review comments
1 parent cebba8e commit 22d36f8

File tree

3 files changed

+35
-16
lines changed

3 files changed

+35
-16
lines changed

mlir/python/mlir/_mlir_libs/__init__.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -156,16 +156,23 @@ def __init__(self, load_on_create_dialects=None, *args, **kwargs):
156156
if not disable_multithreading:
157157
self.enable_multithreading(True)
158158
if load_on_create_dialects is not None:
159-
logger.debug("Loading all dialects from load_on_create_dialects arg %r", _load_on_create_dialects)
159+
logger.debug(
160+
"Loading all dialects from load_on_create_dialects arg %r",
161+
load_on_create_dialects,
162+
)
160163
for dialect in load_on_create_dialects:
161-
# Load dialect.
164+
# This triggers loading the dialect into the context.
162165
_ = self.dialects[dialect]
163166
else:
164167
if disable_load_all_available_dialects:
165-
if _load_on_create_dialects:
166-
logger.debug("Loading all dialects from global load_on_create_dialects %r", _load_on_create_dialects)
167-
for dialect in _load_on_create_dialects:
168-
# Load dialect.
168+
dialects = get_load_on_create_dialects()
169+
if dialects:
170+
logger.debug(
171+
"Loading all dialects from global load_on_create_dialects %r",
172+
dialects,
173+
)
174+
for dialect in dialects:
175+
# This triggers loading the dialect into the context.
169176
_ = self.dialects[dialect]
170177
else:
171178
logger.debug("Loading all available dialects")

mlir/python/mlir/ir.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55
from ._mlir_libs._mlir.ir import *
66
from ._mlir_libs._mlir.ir import _GlobalDebug
77
from ._mlir_libs._mlir import register_type_caster, register_value_caster
8-
from ._mlir_libs import get_dialect_registry, append_load_on_create_dialect, get_load_on_create_dialects
8+
from ._mlir_libs import (
9+
get_dialect_registry,
10+
append_load_on_create_dialect,
11+
get_load_on_create_dialects,
12+
)
913

1014

1115
# Convenience decorator for registering user-friendly Attribute builders.

mlir/test/python/ir/dialects.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def testDialectLoadOnCreate():
129129
with Context(load_on_create_dialects=[]) as ctx:
130130
ctx.emit_error_diagnostics = True
131131
ctx.allow_unregistered_dialects = True
132-
132+
133133
def callback(d):
134134
# CHECK: DIAGNOSTIC
135135
# CHECK-SAME: op created with unregistered dialect
@@ -139,13 +139,21 @@ def callback(d):
139139
handler = ctx.attach_diagnostic_handler(callback)
140140
loc = Location.unknown(ctx)
141141
try:
142-
op = Operation.create("arith.addi", loc=loc)
143-
ctx.allow_unregistered_dialects = False
144-
op.verify()
142+
op = Operation.create("arith.addi", loc=loc)
143+
ctx.allow_unregistered_dialects = False
144+
op.verify()
145145
except MLIRError as e:
146-
pass
147-
148-
with Context(load_on_create_dialects=["func"]) as ctx:
149-
loc = Location.unknown(ctx)
150-
fn = Operation.create("func.func", loc=loc)
146+
pass
151147

148+
with Context(load_on_create_dialects=["func"]) as ctx:
149+
loc = Location.unknown(ctx)
150+
fn = Operation.create("func.func", loc=loc)
151+
152+
# TODO: This may require an update if a site wide policy is set.
153+
# CHECK: Load on create: []
154+
print(f"Load on create: {get_load_on_create_dialects()}")
155+
append_load_on_create_dialect("func")
156+
# CHECK: Load on create:
157+
# CHECK-SAME: func
158+
print(f"Load on create: {get_load_on_create_dialects()}")
159+
print(get_load_on_create_dialects())

0 commit comments

Comments
 (0)