Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 0 additions & 17 deletions pytensor/compile/function/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,23 +275,6 @@ def opt_log1p(node):
else:
output_keys = None

if name is None:
# Determine possible file names
source_file = re.sub(r"\.pyc?", ".py", __file__)
compiled_file = source_file + "c"

stack = tb.extract_stack()
idx = len(stack) - 1

last_frame = stack[idx]
if last_frame[0] == source_file or last_frame[0] == compiled_file:
func_frame = stack[idx - 1]
while "pytensor/graph" in func_frame[0] and idx > 0:
idx -= 1
# This can happen if we call var.eval()
func_frame = stack[idx - 1]
name = func_frame[0] + ":" + str(func_frame[1])

if updates is None:
updates = []

Expand Down
45 changes: 18 additions & 27 deletions pytensor/compile/function/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,6 +880,21 @@ def _restore_defaults(self):
value = value.storage[0]
self[i] = value

def add_note_to_invalid_argument_exception(self, e, arg_container, arg):
i = self.input_storage.index(arg_container)
function_name = (
f"PyTensor function '{self.name}'" if self.name else "PyTensor function"
)
argument_name = (
f"argument '{arg.name}'" if getattr(arg, "name", None) else "argument"
)
where = (
""
if config.exception_verbosity == "low"
else get_variable_trace_string(self.maker.inputs[i].variable)
)
e.add_note(f"\nInvalid {argument_name} to {function_name} at index {i}.{where}")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we grab the name of the symbolic input and show that as well (if its not None)?


def __call__(self, *args, output_subset=None, **kwargs):
"""
Evaluates value of a function on given arguments.
Expand Down Expand Up @@ -947,34 +962,10 @@ def __call__(self, *args, output_subset=None, **kwargs):
strict=arg_container.strict,
allow_downcast=arg_container.allow_downcast,
)

except Exception as e:
i = input_storage.index(arg_container)
function_name = "pytensor function"
argument_name = "argument"
if self.name:
function_name += ' with name "' + self.name + '"'
if hasattr(arg, "name") and arg.name:
argument_name += ' with name "' + arg.name + '"'
where = get_variable_trace_string(self.maker.inputs[i].variable)
if len(e.args) == 1:
e.args = (
"Bad input "
+ argument_name
+ " to "
+ function_name
+ f" at index {int(i)} (0-based). {where}"
+ e.args[0],
)
else:
e.args = (
"Bad input "
+ argument_name
+ " to "
+ function_name
+ f" at index {int(i)} (0-based). {where}"
) + e.args
self._restore_defaults()
self.add_note_to_invalid_argument_exception(
e, arg_container, arg
)
raise
arg_container.provided += 1

Expand Down
11 changes: 2 additions & 9 deletions pytensor/configdefaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,15 +644,8 @@ def add_error_and_warning_configvars():
# on all important apply nodes.
config.add(
"exception_verbosity",
"If 'low', the text of exceptions will generally refer "
"to apply nodes with short names such as "
"Elemwise{add_no_inplace}. If 'high', some exceptions "
"will also refer to apply nodes with long descriptions "
""" like:
A. Elemwise{add_no_inplace}
B. log_likelihood_v_given_h
C. log_likelihood_h""",
EnumStr("low", ["high"]),
"Verbosity of exceptions generated by PyTensor functions.",
EnumStr("low", ["medium", "high"]),
in_c_key=False,
)

Expand Down
6 changes: 6 additions & 0 deletions pytensor/link/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,12 @@ def raise_with_op(
# print a simple traceback from KeyboardInterrupt
raise exc_value.with_traceback(exc_trace)

if verbosity == "low":
exc_value.add_note(
"\nHINT: Set PyTensor `config.exception_verbosity` to `medium` or `high` for more information about the source of the error."
)
raise exc_value.with_traceback(exc_trace)

trace = getattr(node.outputs[0].tag, "trace", ())

exc_value.__thunk_trace__ = trace
Expand Down
5 changes: 3 additions & 2 deletions tests/compile/function/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ def test_function_dump():
def test_function_name():
x = vector("x")
func = function([x], x + 1.0)

assert __file__ in func.name
assert func.name is None
func = function([x], x + 1.0, name="my_func")
assert func.name == "my_func"


def test_trust_input():
Expand Down
6 changes: 5 additions & 1 deletion tests/link/test_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,9 +421,13 @@ def make_thunk(self, *args, **kwargs):

z = BadOp()(a)

with pytest.raises(Exception, match=r".*Apply node that caused the error.*"):
with pytest.raises(Exception, match=r"bad Op"):
function([a], z, mode=Mode(optimizer=None, linker=linker))

with config.change_flags(exception_verbosity="high"):
with pytest.raises(Exception, match=r".*Apply node that caused the error.*"):
function([a], z, mode=Mode(optimizer=None, linker=linker))


def test_VM_exception():
class SomeVM(VM):
Expand Down
Loading