@@ -880,6 +880,21 @@ def _restore_defaults(self):
880
880
value = value .storage [0 ]
881
881
self [i ] = value
882
882
883
+ def add_note_to_invalid_argument_exception (self , e , arg_container , arg ):
884
+ i = self .input_storage .index (arg_container )
885
+ function_name = (
886
+ f"PyTensor function '{ self .name } '" if self .name else "PyTensor function"
887
+ )
888
+ argument_name = (
889
+ f"argument '{ arg .name } '" if getattr (arg , "name" , None ) else "argument"
890
+ )
891
+ where = (
892
+ ""
893
+ if config .exception_verbosity == "low"
894
+ else get_variable_trace_string (self .maker .inputs [i ].variable )
895
+ )
896
+ e .add_note (f"\n Invalid { argument_name } to { function_name } at index { i } .{ where } " )
897
+
883
898
def __call__ (self , * args , output_subset = None , ** kwargs ):
884
899
"""
885
900
Evaluates value of a function on given arguments.
@@ -947,35 +962,11 @@ def __call__(self, *args, output_subset=None, **kwargs):
947
962
strict = arg_container .strict ,
948
963
allow_downcast = arg_container .allow_downcast ,
949
964
)
950
-
951
965
except Exception as e :
952
- i = input_storage .index (arg_container )
953
- function_name = "pytensor function"
954
- argument_name = "argument"
955
- if self .name :
956
- function_name += ' with name "' + self .name + '"'
957
- if hasattr (arg , "name" ) and arg .name :
958
- argument_name += ' with name "' + arg .name + '"'
959
- where = get_variable_trace_string (self .maker .inputs [i ].variable )
960
- if len (e .args ) == 1 :
961
- e .args = (
962
- "Bad input "
963
- + argument_name
964
- + " to "
965
- + function_name
966
- + f" at index { int (i )} (0-based). { where } "
967
- + e .args [0 ],
968
- )
969
- else :
970
- e .args = (
971
- "Bad input "
972
- + argument_name
973
- + " to "
974
- + function_name
975
- + f" at index { int (i )} (0-based). { where } "
976
- ) + e .args
977
- self ._restore_defaults ()
978
- raise
966
+ self .add_note_to_invalid_argument_exception (
967
+ e , arg_container , arg
968
+ )
969
+ raise e
979
970
arg_container .provided += 1
980
971
981
972
# Set keyword arguments
0 commit comments