Skip to content

Commit 837f98e

Browse files
committed
Respect predefined modes in get_default_mode
Also make linker and optimizer non-mutable config as the mode is cached after using them for the first time.
1 parent a0fe30d commit 837f98e

File tree

4 files changed

+38
-30
lines changed

4 files changed

+38
-30
lines changed

pytensor/compile/mode.py

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -503,46 +503,45 @@ def get_mode(orig_string):
503503
if not isinstance(string, str):
504504
return string # it is hopefully already a mode...
505505

506+
if string in predefined_modes:
507+
return predefined_modes[string]
508+
509+
if string not in ("Mode", "DebugMode", "DEBUG_MODE", "NanGuardMode"):
510+
raise ValueError(f"No predefined mode exist for string: {string}")
511+
506512
global instantiated_default_mode
507513
# The default mode is cached. However, config.mode can change
508514
# If instantiated_default_mode has the right class, use it.
515+
509516
if orig_string is None and instantiated_default_mode:
510-
if string in predefined_modes:
511-
default_mode_class = predefined_modes[string].__class__.__name__
512-
else:
513-
default_mode_class = string
517+
default_mode_class = string
518+
# FIXME: This is flawed, we should use proper object comparison.
514519
if instantiated_default_mode.__class__.__name__ == default_mode_class:
515520
return instantiated_default_mode
516521

517-
if string in ("Mode", "DebugMode", "NanGuardMode"):
518-
if string == "DebugMode":
519-
# need to import later to break circular dependency.
520-
from .debugmode import DebugMode
522+
if string in ("DebugMode", "DEBUG_MODE"):
523+
# need to import later to break circular dependency.
524+
from .debugmode import DebugMode
521525

522-
# DebugMode use its own linker.
523-
ret = DebugMode(optimizer=config.optimizer)
524-
elif string == "NanGuardMode":
525-
# need to import later to break circular dependency.
526-
from .nanguardmode import NanGuardMode
526+
# DebugMode use its own linker.
527+
ret = DebugMode(optimizer=config.optimizer)
528+
elif string == "NanGuardMode":
529+
# need to import later to break circular dependency.
530+
from .nanguardmode import NanGuardMode
527531

528-
# NanGuardMode use its own linker.
529-
ret = NanGuardMode(True, True, True, optimizer=config.optimizer)
530-
else:
531-
# TODO: Can't we look up the name and invoke it rather than using eval here?
532-
ret = eval(string + "(linker=config.linker, optimizer=config.optimizer)")
533-
elif string in predefined_modes:
534-
ret = predefined_modes[string]
532+
# NanGuardMode use its own linker.
533+
ret = NanGuardMode(True, True, True, optimizer=config.optimizer)
535534
else:
536-
raise Exception(f"No predefined mode exist for string: {string}")
535+
ret = Mode(linker=config.linker, optimizer=config.optimizer)
537536

538537
if orig_string is None:
539-
# Build and cache the default mode
540538
if config.optimizer_excluding:
541539
ret = ret.excluding(*config.optimizer_excluding.split(":"))
542540
if config.optimizer_including:
543541
ret = ret.including(*config.optimizer_including.split(":"))
544542
if config.optimizer_requiring:
545543
ret = ret.requiring(*config.optimizer_requiring.split(":"))
544+
# Override the cache with the new class mode
546545
instantiated_default_mode = ret
547546

548547
return ret

pytensor/configdefaults.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,8 @@ def add_compile_configvars():
388388
config.add(
389389
"linker",
390390
"Default linker used if the pytensor flags mode is Mode",
391-
EnumStr("cvm", linker_options),
391+
# Not mutable because the default mode is cached after the first use.
392+
EnumStr("cvm", linker_options, mutable=False),
392393
in_c_key=False,
393394
)
394395

@@ -411,6 +412,7 @@ def add_compile_configvars():
411412
EnumStr(
412413
"o4",
413414
["o3", "o2", "o1", "unsafe", "fast_run", "fast_compile", "merge", "None"],
415+
mutable=False, # Not mutable because the default mode is cached after the first use.
414416
),
415417
in_c_key=False,
416418
)

tests/compile/function/test_types.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,14 +1105,10 @@ def test_optimizations_preserved(self):
11051105
((a.T.T) * (dot(xm, (sm.T.T.T)) + x).T * (x / x) + s),
11061106
)
11071107
old_default_mode = config.mode
1108-
old_default_opt = config.optimizer
1109-
old_default_link = config.linker
11101108
try:
11111109
try:
11121110
str_f = pickle.dumps(f, protocol=-1)
1113-
config.mode = "Mode"
1114-
config.linker = "py"
1115-
config.optimizer = "None"
1111+
config.mode = "NUMBA"
11161112
g = pickle.loads(str_f)
11171113
# print g.maker.mode
11181114
# print compile.mode.default_mode
@@ -1121,8 +1117,6 @@ def test_optimizations_preserved(self):
11211117
g = "ok"
11221118
finally:
11231119
config.mode = old_default_mode
1124-
config.optimizer = old_default_opt
1125-
config.linker = old_default_link
11261120

11271121
if g == "ok":
11281122
return

tests/compile/test_mode.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from pytensor.graph.features import NoOutputFromInplace
1414
from pytensor.graph.rewriting.db import RewriteDatabaseQuery, SequenceDB
1515
from pytensor.link.basic import LocalLinker
16+
from pytensor.link.jax import JAXLinker
1617
from pytensor.tensor.math import dot, tanh
1718
from pytensor.tensor.type import matrix, vector
1819

@@ -142,3 +143,15 @@ class MyLinker(LocalLinker):
142143
test_mode = Mode(linker=MyLinker())
143144
with pytest.raises(Exception):
144145
get_target_language(test_mode)
146+
147+
148+
def test_predefined_modes_respected():
149+
default_mode = get_default_mode()
150+
assert not isinstance(default_mode.linker, JAXLinker)
151+
152+
with config.change_flags(mode="JAX"):
153+
jax_mode = get_default_mode()
154+
assert isinstance(jax_mode.linker, JAXLinker)
155+
156+
default_mode_again = get_default_mode()
157+
assert not isinstance(default_mode_again.linker, JAXLinker)

0 commit comments

Comments
 (0)