Skip to content

Commit f67b3bd

Browse files
committed
Make exceptions less verbose by default
1 parent 92c3b49 commit f67b3bd

File tree

3 files changed

+26
-36
lines changed

3 files changed

+26
-36
lines changed

pytensor/compile/function/types.py

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -880,6 +880,21 @@ def _restore_defaults(self):
880880
value = value.storage[0]
881881
self[i] = value
882882

883+
def add_note_to_invalid_argument_exception(self, e, arg_container, arg):
884+
i = self.input_storage.index(arg_container)
885+
function_name = (
886+
f"PyTensor function '{self.name}'" if self.name else "PyTensor function"
887+
)
888+
argument_name = (
889+
f"argument '{arg.name}'" if getattr(arg, "name", None) else "argument"
890+
)
891+
where = (
892+
""
893+
if config.exception_verbosity == "low"
894+
else get_variable_trace_string(self.maker.inputs[i].variable)
895+
)
896+
e.add_note(f"\nInvalid {argument_name} to {function_name} at index {i}.{where}")
897+
883898
def __call__(self, *args, output_subset=None, **kwargs):
884899
"""
885900
Evaluates value of a function on given arguments.
@@ -947,34 +962,10 @@ def __call__(self, *args, output_subset=None, **kwargs):
947962
strict=arg_container.strict,
948963
allow_downcast=arg_container.allow_downcast,
949964
)
950-
951965
except Exception as e:
952-
i = input_storage.index(arg_container)
953-
function_name = "pytensor function"
954-
argument_name = "argument"
955-
if self.name:
956-
function_name += ' with name "' + self.name + '"'
957-
if hasattr(arg, "name") and arg.name:
958-
argument_name += ' with name "' + arg.name + '"'
959-
where = get_variable_trace_string(self.maker.inputs[i].variable)
960-
if len(e.args) == 1:
961-
e.args = (
962-
"Bad input "
963-
+ argument_name
964-
+ " to "
965-
+ function_name
966-
+ f" at index {int(i)} (0-based). {where}"
967-
+ e.args[0],
968-
)
969-
else:
970-
e.args = (
971-
"Bad input "
972-
+ argument_name
973-
+ " to "
974-
+ function_name
975-
+ f" at index {int(i)} (0-based). {where}"
976-
) + e.args
977-
self._restore_defaults()
966+
self.add_note_to_invalid_argument_exception(
967+
e, arg_container, arg
968+
)
978969
raise
979970
arg_container.provided += 1
980971

pytensor/configdefaults.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -644,15 +644,8 @@ def add_error_and_warning_configvars():
644644
# on all important apply nodes.
645645
config.add(
646646
"exception_verbosity",
647-
"If 'low', the text of exceptions will generally refer "
648-
"to apply nodes with short names such as "
649-
"Elemwise{add_no_inplace}. If 'high', some exceptions "
650-
"will also refer to apply nodes with long descriptions "
651-
""" like:
652-
A. Elemwise{add_no_inplace}
653-
B. log_likelihood_v_given_h
654-
C. log_likelihood_h""",
655-
EnumStr("low", ["high"]),
647+
"Verbosity of exceptions generated by PyTensor functions.",
648+
EnumStr("low", ["medium", "high"]),
656649
in_c_key=False,
657650
)
658651

pytensor/link/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,12 @@ def raise_with_op(
313313
# print a simple traceback from KeyboardInterrupt
314314
raise exc_value.with_traceback(exc_trace)
315315

316+
if verbosity == "low":
317+
exc_value.add_note(
318+
"\nHINT: Set PyTensor `config.exception_verbosity` to `medium` or `high` for more information about the source of the error."
319+
)
320+
raise exc_value.with_traceback(exc_trace)
321+
316322
trace = getattr(node.outputs[0].tag, "trace", ())
317323

318324
exc_value.__thunk_trace__ = trace

0 commit comments

Comments
 (0)