Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pytensor/compile/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
PrintCurrentFunctionGraph,
get_default_mode,
get_mode,
instantiated_default_mode,
local_useless,
optdb,
predefined_linkers,
Expand Down
78 changes: 37 additions & 41 deletions pytensor/compile/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,58 +492,54 @@
"PYTORCH": PYTORCH,
}

instantiated_default_mode = None
_CACHED_RUNTIME_MODES: dict[str, Mode] = {}


def get_mode(orig_string):
if orig_string is None:
string = config.mode
else:
string = orig_string

if not isinstance(string, str):
return string # it is hopefully already a mode...

global instantiated_default_mode
# The default mode is cached. However, config.mode can change
# If instantiated_default_mode has the right class, use it.
if orig_string is None and instantiated_default_mode:
if string in predefined_modes:
default_mode_class = predefined_modes[string].__class__.__name__
else:
default_mode_class = string
if instantiated_default_mode.__class__.__name__ == default_mode_class:
return instantiated_default_mode

if string in ("Mode", "DebugMode", "NanGuardMode"):
if string == "DebugMode":
# need to import later to break circular dependency.
from .debugmode import DebugMode

# DebugMode use its own linker.
ret = DebugMode(optimizer=config.optimizer)
elif string == "NanGuardMode":
# need to import later to break circular dependency.
from .nanguardmode import NanGuardMode

# NanGuardMode use its own linker.
ret = NanGuardMode(True, True, True, optimizer=config.optimizer)
else:
# TODO: Can't we look up the name and invoke it rather than using eval here?
ret = eval(string + "(linker=config.linker, optimizer=config.optimizer)")
elif string in predefined_modes:
ret = predefined_modes[string]
else:
raise Exception(f"No predefined mode exist for string: {string}")
# Keep the original string for error messages
upper_string = string.upper()

if orig_string is None:
# Build and cache the default mode
if config.optimizer_excluding:
ret = ret.excluding(*config.optimizer_excluding.split(":"))
if config.optimizer_including:
ret = ret.including(*config.optimizer_including.split(":"))
if config.optimizer_requiring:
ret = ret.requiring(*config.optimizer_requiring.split(":"))
instantiated_default_mode = ret
if upper_string in predefined_modes:
return predefined_modes[upper_string]

global _CACHED_RUNTIME_MODES

if upper_string in _CACHED_RUNTIME_MODES:
return _CACHED_RUNTIME_MODES[upper_string]

# Need to define the mode for the first time
if upper_string == "MODE":
ret = Mode(linker=config.linker, optimizer=config.optimizer)
elif upper_string in ("DEBUGMODE", "DEBUG_MODE"):
from pytensor.compile.debugmode import DebugMode

# DebugMode use its own linker.
ret = DebugMode(optimizer=config.optimizer)
elif upper_string == "NANGUARDMODE":
from pytensor.compile.nanguardmode import NanGuardMode

Check warning on line 527 in pytensor/compile/mode.py

View check run for this annotation

Codecov / codecov/patch

pytensor/compile/mode.py#L527

Added line #L527 was not covered by tests

# NanGuardMode use its own linker.
ret = NanGuardMode(True, True, True, optimizer=config.optimizer)

Check warning on line 530 in pytensor/compile/mode.py

View check run for this annotation

Codecov / codecov/patch

pytensor/compile/mode.py#L530

Added line #L530 was not covered by tests

else:
raise ValueError(f"No predefined mode exist for string: {string}")

Check warning on line 533 in pytensor/compile/mode.py

View check run for this annotation

Codecov / codecov/patch

pytensor/compile/mode.py#L533

Added line #L533 was not covered by tests

if config.optimizer_excluding:
ret = ret.excluding(*config.optimizer_excluding.split(":"))

Check warning on line 536 in pytensor/compile/mode.py

View check run for this annotation

Codecov / codecov/patch

pytensor/compile/mode.py#L536

Added line #L536 was not covered by tests
if config.optimizer_including:
ret = ret.including(*config.optimizer_including.split(":"))

Check warning on line 538 in pytensor/compile/mode.py

View check run for this annotation

Codecov / codecov/patch

pytensor/compile/mode.py#L538

Added line #L538 was not covered by tests
if config.optimizer_requiring:
ret = ret.requiring(*config.optimizer_requiring.split(":"))

Check warning on line 540 in pytensor/compile/mode.py

View check run for this annotation

Codecov / codecov/patch

pytensor/compile/mode.py#L540

Added line #L540 was not covered by tests
# Cache the mode for next time
_CACHED_RUNTIME_MODES[upper_string] = ret

return ret

Expand Down
4 changes: 3 additions & 1 deletion pytensor/configdefaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,8 @@ def add_compile_configvars():
config.add(
"linker",
"Default linker used if the pytensor flags mode is Mode",
EnumStr("cvm", linker_options),
# Not mutable because the default mode is cached after the first use.
EnumStr("cvm", linker_options, mutable=False),
in_c_key=False,
)

Expand All @@ -410,6 +411,7 @@ def add_compile_configvars():
EnumStr(
"o4",
["o3", "o2", "o1", "unsafe", "fast_run", "fast_compile", "merge", "None"],
mutable=False, # Not mutable because the default mode is cached after the first use.
),
in_c_key=False,
)
Expand Down
8 changes: 1 addition & 7 deletions tests/compile/function/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1105,14 +1105,10 @@ def test_optimizations_preserved(self):
((a.T.T) * (dot(xm, (sm.T.T.T)) + x).T * (x / x) + s),
)
old_default_mode = config.mode
old_default_opt = config.optimizer
old_default_link = config.linker
try:
try:
str_f = pickle.dumps(f, protocol=-1)
config.mode = "Mode"
config.linker = "py"
config.optimizer = "None"
config.mode = "NUMBA"
g = pickle.loads(str_f)
# print g.maker.mode
# print compile.mode.default_mode
Expand All @@ -1121,8 +1117,6 @@ def test_optimizations_preserved(self):
g = "ok"
finally:
config.mode = old_default_mode
config.optimizer = old_default_opt
config.linker = old_default_link

if g == "ok":
return
Expand Down
13 changes: 13 additions & 0 deletions tests/compile/test_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pytensor.graph.features import NoOutputFromInplace
from pytensor.graph.rewriting.db import RewriteDatabaseQuery, SequenceDB
from pytensor.link.basic import LocalLinker
from pytensor.link.jax import JAXLinker
from pytensor.tensor.math import dot, tanh
from pytensor.tensor.type import matrix, vector

Expand Down Expand Up @@ -142,3 +143,15 @@ class MyLinker(LocalLinker):
test_mode = Mode(linker=MyLinker())
with pytest.raises(Exception):
get_target_language(test_mode)


def test_predefined_modes_respected():
default_mode = get_default_mode()
assert not isinstance(default_mode.linker, JAXLinker)

with config.change_flags(mode="JAX"):
jax_mode = get_default_mode()
assert isinstance(jax_mode.linker, JAXLinker)

default_mode_again = get_default_mode()
assert not isinstance(default_mode_again.linker, JAXLinker)
Loading