Skip to content

Commit 60abfaa

Browse files
committed
Respect predefined modes in get_default_mode
Also make linker and optimizer non-mutable config as the mode is cached after using them for the first time.
1 parent a0fe30d commit 60abfaa

File tree

3 files changed

+38
-23
lines changed

3 files changed

+38
-23
lines changed

pytensor/compile/mode.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -503,40 +503,40 @@ def get_mode(orig_string):
503503
if not isinstance(string, str):
504504
return string # it is hopefully already a mode...
505505

506+
if string in predefined_modes:
507+
return predefined_modes[string]
508+
509+
if string not in ("Mode", "DebugMode", "DEBUG_MODE", "NanGuardMode"):
510+
raise ValueError(f"No predefined mode exist for string: {string}")
511+
506512
global instantiated_default_mode
507513
# The default mode is cached. However, config.mode can change
508514
# If instantiated_default_mode has the right class, use it.
515+
509516
if orig_string is None and instantiated_default_mode:
510-
if string in predefined_modes:
511-
default_mode_class = predefined_modes[string].__class__.__name__
512-
else:
513-
default_mode_class = string
517+
# This includes a string in ("Mode", "DebugMode", "DEBUG_MODE", "NanGuardMode")
518+
default_mode_class = string
519+
# FIXME: This is flawed, we should use proper object comparison.
514520
if instantiated_default_mode.__class__.__name__ == default_mode_class:
515521
return instantiated_default_mode
516522

517-
if string in ("Mode", "DebugMode", "NanGuardMode"):
518-
if string == "DebugMode":
519-
# need to import later to break circular dependency.
520-
from .debugmode import DebugMode
523+
if string in ("DebugMode", "DEBUG_MODE"):
524+
# need to import later to break circular dependency.
525+
from .debugmode import DebugMode
521526

522-
# DebugMode use its own linker.
523-
ret = DebugMode(optimizer=config.optimizer)
524-
elif string == "NanGuardMode":
525-
# need to import later to break circular dependency.
526-
from .nanguardmode import NanGuardMode
527+
# DebugMode use its own linker.
528+
ret = DebugMode(optimizer=config.optimizer)
529+
elif string == "NanGuardMode":
530+
# need to import later to break circular dependency.
531+
from .nanguardmode import NanGuardMode
527532

528-
# NanGuardMode use its own linker.
529-
ret = NanGuardMode(True, True, True, optimizer=config.optimizer)
530-
else:
531-
# TODO: Can't we look up the name and invoke it rather than using eval here?
532-
ret = eval(string + "(linker=config.linker, optimizer=config.optimizer)")
533-
elif string in predefined_modes:
534-
ret = predefined_modes[string]
533+
# NanGuardMode use its own linker.
534+
ret = NanGuardMode(True, True, True, optimizer=config.optimizer)
535535
else:
536-
raise Exception(f"No predefined mode exist for string: {string}")
536+
ret = Mode(linker=config.linker, optimizer=config.optimizer)
537537

538538
if orig_string is None:
539-
# Build and cache the default mode
539+
# Build and cache the ~~default~~ first requested mode
540540
if config.optimizer_excluding:
541541
ret = ret.excluding(*config.optimizer_excluding.split(":"))
542542
if config.optimizer_including:

pytensor/configdefaults.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,8 @@ def add_compile_configvars():
388388
config.add(
389389
"linker",
390390
"Default linker used if the pytensor flags mode is Mode",
391-
EnumStr("cvm", linker_options),
391+
# Not mutable because the default mode is cached after the first use.
392+
EnumStr("cvm", linker_options, mutable=False),
392393
in_c_key=False,
393394
)
394395

@@ -411,6 +412,7 @@ def add_compile_configvars():
411412
EnumStr(
412413
"o4",
413414
["o3", "o2", "o1", "unsafe", "fast_run", "fast_compile", "merge", "None"],
415+
mutable=False, # Not mutable because the default mode is cached after the first use.
414416
),
415417
in_c_key=False,
416418
)

tests/compile/test_mode.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from pytensor.graph.features import NoOutputFromInplace
1414
from pytensor.graph.rewriting.db import RewriteDatabaseQuery, SequenceDB
1515
from pytensor.link.basic import LocalLinker
16+
from pytensor.link.jax import JAXLinker
1617
from pytensor.tensor.math import dot, tanh
1718
from pytensor.tensor.type import matrix, vector
1819

@@ -142,3 +143,15 @@ class MyLinker(LocalLinker):
142143
test_mode = Mode(linker=MyLinker())
143144
with pytest.raises(Exception):
144145
get_target_language(test_mode)
146+
147+
148+
def test_predefined_modes_respected():
149+
default_mode = get_default_mode()
150+
assert not isinstance(default_mode.linker, JAXLinker)
151+
152+
with config.change_flags(mode="JAX"):
153+
jax_mode = get_default_mode()
154+
assert isinstance(jax_mode.linker, JAXLinker)
155+
156+
default_mode_again = get_default_mode()
157+
assert not isinstance(default_mode_again.linker, JAXLinker)

0 commit comments

Comments
 (0)