1717from collections .abc import Callable , Sequence , Hashable
1818from contextlib import contextmanager
1919from functools import partial
20- import inspect
2120import itertools as it
2221import operator as op
2322from typing import Any , NamedTuple , Union
4645 InputType , OutputType , get_referent , JaxprEqnContext )
4746from jax ._src .state .types import AbstractRef
4847from jax ._src .tree_util import (PyTreeDef , treedef_tuple , tree_unflatten ,
49- tree_flatten , tree_structure , KeyPath , generate_key_paths ,
48+ tree_flatten , tree_structure , generate_key_paths ,
5049 keystr )
5150from jax ._src .util import (unzip2 , safe_zip , safe_map , toposort , split_list ,
5251 merge_lists , partition_list , OrderedSet ,
@@ -1529,8 +1528,7 @@ class DynamicJaxprTracer(core.Tracer):
15291528 def __init__ (self , trace , aval , line_info = None ):
15301529 self ._trace = trace
15311530 self ._line_info = line_info
1532- # Needed for UnexpectedTracerError.
1533- self ._debug_info = self ._trace .frame .debug_info
1531+ self ._debug_info = self ._trace .frame .debug_info # for UnexpectedTracerError
15341532 self .aval = aval
15351533
15361534 def full_lower (self ):
@@ -1551,11 +1549,11 @@ def _origin_msg(self):
15511549
15521550 origin = ("The error occurred while tracing the function "
15531551 f"{ dbg .func_src_info or '<unknown>' } for { dbg .traced_for } . " )
1554- arg_info = arg_info_all ( dbg )
1555- # TODO(mattjj): figure out when not (invar_pos < len(arg_info))
1556- if invar_pos and arg_info and all ( i < len ( arg_info ) for i in invar_pos ):
1557- arg_info = [ arg_info [ i ] for i in invar_pos ]
1558- arg_names = [ f' { name } { keystr ( path ) } ' for name , path in arg_info ]
1552+ if invar_pos and dbg . arg_names :
1553+ try :
1554+ arg_names = [ dbg . arg_names [ i ] for i in invar_pos ]
1555+ except IndexError :
1556+ return "" # TODO(mattjj): figure out when not (invar_pos < len( arg_info))
15591557 if len (arg_names ) == 1 :
15601558 arg_info_str = f"the argument { arg_names [0 ]} "
15611559 elif len (arg_names ) == 2 :
@@ -1632,7 +1630,7 @@ class JaxprStackFrame:
16321630 attrs_tracked : list [tuple [Any , str ]]
16331631 attrs_inits : list
16341632 attrs_vars : list [Var ]
1635- debug_info : DebugInfo | None
1633+ debug_info : lu . TracingDebugInfo | None
16361634
16371635 def __init__ (self ):
16381636 self .gensym = core .gensym ()
@@ -2116,64 +2114,42 @@ def _jvp_jaxpr_zeros(f, store, in_zeros, zero_avals, *primal_tangent_avals):
21162114 store .store (out_zeros )
21172115 return [* out_primals , * out_nz_tangents ]
21182116
2119- # TODO(mattjj): remove this DebugInfo and helper functions, replace with
2120- # api_util.py versions
2121-
2122- class DebugInfo (NamedTuple ):
2123- func_src_info : str | None # f'{fun.__name__} at {filename}:{lineno}'
2124- signature : inspect .Signature | None # inspect.signature(fun)
2125- in_tree : PyTreeDef | None # caller/constructor might not have this info
2126- out_tree : Callable [[], PyTreeDef ] | None # lazy, not avail at trace time
2127- has_kwargs : bool # whether in_tree corresponds to (args, kwargs) or args
2128- traced_for : str # "jit", "scan", "make_jaxpr", etc
2129-
2130- def debug_info (fn : Callable , in_tree : PyTreeDef | None ,
2131- out_tree_thunk : Callable [[], PyTreeDef ] | None ,
2132- has_kwargs : bool , traced_for : str ) -> DebugInfo :
2133- sig = api_util .fun_signature (fn )
2117+ # Callers should be using linear_util.debug_info instead!
2118+ def debug_info (
2119+ fn : Callable ,
2120+ in_tree : PyTreeDef | None ,
2121+ out_tree_thunk : Callable [[], PyTreeDef ] | None ,
2122+ has_kwargs : bool ,
2123+ traced_for : str
2124+ ) -> lu .TracingDebugInfo | None :
21342125 src_info = fun_sourceinfo (fn )
2135- return DebugInfo (src_info , sig , in_tree , out_tree_thunk , has_kwargs ,
2136- traced_for )
2137-
2138- def debug_info_final (fn : lu .WrappedFun , traced_for : str ) -> DebugInfo :
2139- "Make a DebugInfo from data available to final-style primitives like pmap."
2140- in_tree , out_tree , has_kws = flattened_fun_in_tree (fn ) or (None , None , False )
2141- return debug_info (fn .f , in_tree , out_tree , has_kws , traced_for )
2142-
2143- def arg_info_all (dbg : DebugInfo ) -> list [tuple [str , KeyPath ]] | None :
2144- ba = None if dbg .in_tree is None else sig_info (dbg )
2145- if ba is None : return None
2146- return [(name , key_path ) for name , dummy_arg in ba .arguments .items ()
2147- for key_path , _ in generate_key_paths (dummy_arg )]
2148-
2149- def sig_info (dbg : DebugInfo ) -> inspect .BoundArguments | None :
2150- if dbg .in_tree is None or dbg .signature is None : return None
2151- try :
2152- dummy_args = tree_unflatten (dbg .in_tree , [False ] * dbg .in_tree .num_leaves )
2153- except :
2154- return None
2155- args , kwargs = dummy_args if dbg .has_kwargs else (dummy_args , {})
2156- try :
2157- return dbg .signature .bind (* args , ** kwargs )
2158- except (TypeError , ValueError ):
2159- return None
2160-
2161- def result_info (dbg : DebugInfo ) -> list [KeyPath ] | None :
2162- if dbg .out_tree is None : return None
21632126 try :
2164- num_leaves = dbg .out_tree ().num_leaves
2165- dummy_result = tree_unflatten (dbg .out_tree (), [False ] * num_leaves )
2127+ dummy_args = tree_unflatten (in_tree , [False ] * in_tree .num_leaves ) # type: ignore
2128+ args , kwargs = dummy_args if has_kwargs else (dummy_args , {})
2129+ ba = api_util .fun_signature (fn ).bind (* args , ** kwargs ) # type: ignore
2130+ arg_names = tuple (f'{ name } { keystr (path )} ' for name , dummy in ba .arguments .items ()
2131+ for path , _ in generate_key_paths (dummy ))
21662132 except :
2167- return None
2168- else :
2169- return [path for path , _ in generate_key_paths (dummy_result )]
2133+ arg_names = None
2134+ def result_paths ():
2135+ try :
2136+ out_tree = out_tree_thunk ()
2137+ dummy_result = tree_unflatten (out_tree , [False ] * out_tree .num_leaves )
2138+ except :
2139+ return None
2140+ return tuple (path for path , _ in generate_key_paths (dummy_result ))
2141+ return lu .TracingDebugInfo (traced_for , src_info , arg_names , result_paths ) # type: ignore
2142+
2143+ def debug_info_final (fn : lu .WrappedFun , traced_for : str ) -> lu .TracingDebugInfo | None :
2144+ in_tree , out_tree , has_kws = flattened_fun_in_tree (fn ) or (None , None , False )
2145+ return debug_info (fn .f , in_tree , out_tree , has_kws , traced_for )
21702146
21712147
21722148@profiler .annotate_function
21732149def trace_to_jaxpr_dynamic (
21742150 fun : lu .WrappedFun ,
21752151 in_avals : Sequence [AbstractValue ],
2176- debug_info : DebugInfo | None = None ,
2152+ debug_info : lu . TracingDebugInfo | None = None ,
21772153 * ,
21782154 keep_inputs : list [bool ] | None = None ,
21792155) -> tuple [Jaxpr , list [AbstractValue ], list [Any ],
@@ -2197,7 +2173,7 @@ def trace_to_jaxpr_dynamic(
21972173
21982174@profiler .annotate_function
21992175def trace_to_jaxpr_dynamic2 (
2200- fun : lu .WrappedFun , debug_info : DebugInfo | None = None
2176+ fun : lu .WrappedFun , debug_info : lu . TracingDebugInfo | None = None
22012177 ) -> tuple [Jaxpr , OutputType , list [Any ]]:
22022178
22032179 trace = DynamicJaxprTrace ()
0 commit comments