Skip to content

Commit cddf588

Browse files
committed
Respect predefined modes in get_default_mode
Also allow arbitrary capitalization of the modes. Also make linker and optimizer non-mutable config as the mode is cached after using them for the first time.
1 parent 450efff commit cddf588

File tree

5 files changed

+54
-50
lines changed

5 files changed

+54
-50
lines changed

pytensor/compile/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
PrintCurrentFunctionGraph,
3838
get_default_mode,
3939
get_mode,
40-
instantiated_default_mode,
4140
local_useless,
4241
optdb,
4342
predefined_linkers,

pytensor/compile/mode.py

Lines changed: 37 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -492,58 +492,54 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
492492
"PYTORCH": PYTORCH,
493493
}
494494

495-
instantiated_default_mode = None
495+
_CACHED_RUNTIME_MODES: dict[str, Mode] = {}
496496

497497

498498
def get_mode(orig_string):
499499
if orig_string is None:
500500
string = config.mode
501501
else:
502502
string = orig_string
503+
503504
if not isinstance(string, str):
504505
return string # it is hopefully already a mode...
505506

506-
global instantiated_default_mode
507-
# The default mode is cached. However, config.mode can change
508-
# If instantiated_default_mode has the right class, use it.
509-
if orig_string is None and instantiated_default_mode:
510-
if string in predefined_modes:
511-
default_mode_class = predefined_modes[string].__class__.__name__
512-
else:
513-
default_mode_class = string
514-
if instantiated_default_mode.__class__.__name__ == default_mode_class:
515-
return instantiated_default_mode
516-
517-
if string in ("Mode", "DebugMode", "NanGuardMode"):
518-
if string == "DebugMode":
519-
# need to import later to break circular dependency.
520-
from .debugmode import DebugMode
521-
522-
# DebugMode use its own linker.
523-
ret = DebugMode(optimizer=config.optimizer)
524-
elif string == "NanGuardMode":
525-
# need to import later to break circular dependency.
526-
from .nanguardmode import NanGuardMode
527-
528-
# NanGuardMode use its own linker.
529-
ret = NanGuardMode(True, True, True, optimizer=config.optimizer)
530-
else:
531-
# TODO: Can't we look up the name and invoke it rather than using eval here?
532-
ret = eval(string + "(linker=config.linker, optimizer=config.optimizer)")
533-
elif string in predefined_modes:
534-
ret = predefined_modes[string]
535-
else:
536-
raise Exception(f"No predefined mode exist for string: {string}")
507+
# Keep the original string for error messages
508+
upper_string = string.upper()
537509

538-
if orig_string is None:
539-
# Build and cache the default mode
540-
if config.optimizer_excluding:
541-
ret = ret.excluding(*config.optimizer_excluding.split(":"))
542-
if config.optimizer_including:
543-
ret = ret.including(*config.optimizer_including.split(":"))
544-
if config.optimizer_requiring:
545-
ret = ret.requiring(*config.optimizer_requiring.split(":"))
546-
instantiated_default_mode = ret
510+
if upper_string in predefined_modes:
511+
return predefined_modes[upper_string]
512+
513+
global _CACHED_RUNTIME_MODES
514+
515+
if upper_string in _CACHED_RUNTIME_MODES:
516+
return _CACHED_RUNTIME_MODES[upper_string]
517+
518+
# Need to define the mode for the first time
519+
if upper_string == "MODE":
520+
ret = Mode(linker=config.linker, optimizer=config.optimizer)
521+
elif upper_string in ("DEBUGMODE", "DEBUG_MODE"):
522+
from pytensor.compile.debugmode import DebugMode
523+
524+
# DebugMode use its own linker.
525+
ret = DebugMode(optimizer=config.optimizer)
526+
elif upper_string == "NANGUARDMODE":
527+
from pytensor.compile.nanguardmode import NanGuardMode
528+
529+
# NanGuardMode use its own linker.
530+
ret = NanGuardMode(True, True, True, optimizer=config.optimizer)
531+
532+
else:
533+
raise ValueError(f"No predefined mode exist for string: {string}")
534+
535+
if config.optimizer_excluding:
536+
ret = ret.excluding(*config.optimizer_excluding.split(":"))
537+
if config.optimizer_including:
538+
ret = ret.including(*config.optimizer_including.split(":"))
539+
if config.optimizer_requiring:
540+
ret = ret.requiring(*config.optimizer_requiring.split(":"))
541+
# Cache the mode for next time
542+
_CACHED_RUNTIME_MODES[upper_string] = ret
547543

548544
return ret
549545

pytensor/configdefaults.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,8 @@ def add_compile_configvars():
387387
config.add(
388388
"linker",
389389
"Default linker used if the pytensor flags mode is Mode",
390-
EnumStr("cvm", linker_options),
390+
# Not mutable because the default mode is cached after the first use.
391+
EnumStr("cvm", linker_options, mutable=False),
391392
in_c_key=False,
392393
)
393394

@@ -410,6 +411,7 @@ def add_compile_configvars():
410411
EnumStr(
411412
"o4",
412413
["o3", "o2", "o1", "unsafe", "fast_run", "fast_compile", "merge", "None"],
414+
mutable=False, # Not mutable because the default mode is cached after the first use.
413415
),
414416
in_c_key=False,
415417
)

tests/compile/function/test_types.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,14 +1105,10 @@ def test_optimizations_preserved(self):
11051105
((a.T.T) * (dot(xm, (sm.T.T.T)) + x).T * (x / x) + s),
11061106
)
11071107
old_default_mode = config.mode
1108-
old_default_opt = config.optimizer
1109-
old_default_link = config.linker
11101108
try:
11111109
try:
11121110
str_f = pickle.dumps(f, protocol=-1)
1113-
config.mode = "Mode"
1114-
config.linker = "py"
1115-
config.optimizer = "None"
1111+
config.mode = "NUMBA"
11161112
g = pickle.loads(str_f)
11171113
# print g.maker.mode
11181114
# print compile.mode.default_mode
@@ -1121,8 +1117,6 @@ def test_optimizations_preserved(self):
11211117
g = "ok"
11221118
finally:
11231119
config.mode = old_default_mode
1124-
config.optimizer = old_default_opt
1125-
config.linker = old_default_link
11261120

11271121
if g == "ok":
11281122
return

tests/compile/test_mode.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from pytensor.graph.features import NoOutputFromInplace
1414
from pytensor.graph.rewriting.db import RewriteDatabaseQuery, SequenceDB
1515
from pytensor.link.basic import LocalLinker
16+
from pytensor.link.jax import JAXLinker
1617
from pytensor.tensor.math import dot, tanh
1718
from pytensor.tensor.type import matrix, vector
1819

@@ -142,3 +143,15 @@ class MyLinker(LocalLinker):
142143
test_mode = Mode(linker=MyLinker())
143144
with pytest.raises(Exception):
144145
get_target_language(test_mode)
146+
147+
148+
def test_predefined_modes_respected():
149+
default_mode = get_default_mode()
150+
assert not isinstance(default_mode.linker, JAXLinker)
151+
152+
with config.change_flags(mode="JAX"):
153+
jax_mode = get_default_mode()
154+
assert isinstance(jax_mode.linker, JAXLinker)
155+
156+
default_mode_again = get_default_mode()
157+
assert not isinstance(default_mode_again.linker, JAXLinker)

0 commit comments

Comments
 (0)