Skip to content

Commit 3262770

Browse files
Merge pull request #24516 from mattjj:improve-concreteness-error-in-remat-3
PiperOrigin-RevId: 707376209
2 parents f9737b9 + 9acd4a9 commit 3262770

File tree

6 files changed: 63 additions and 77 deletions

jax/_src/ad_checkpoint.py

Lines changed: 11 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,8 @@
3232
from jax._src import effects
3333
from jax._src import source_info_util
3434
from jax._src import traceback_util
35-
from jax._src.api_util import flatten_fun, shaped_abstractify
35+
from jax._src.api_util import (
36+
flatten_fun, shaped_abstractify, debug_info, fun_sourceinfo, fun_signature)
3637
from jax._src.interpreters import ad
3738
from jax._src.interpreters import batching
3839
from jax._src.interpreters import mlir
@@ -41,7 +42,7 @@
4142
from jax._src.lax import convolution as lax_convolution
4243
from jax._src.lib.mlir.dialects import hlo
4344
from jax._src.traceback_util import api_boundary
44-
from jax._src.tree_util import tree_flatten, tree_unflatten, tree_structure, keystr
45+
from jax._src.tree_util import tree_flatten, tree_unflatten, tree_structure
4546
from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map,
4647
safe_zip, merge_lists, weakref_lru_cache)
4748

@@ -320,10 +321,12 @@ def foo(x, y):
320321
@wraps(fun)
321322
@api_boundary
322323
def fun_remat(*args, **kwargs):
324+
debug = debug_info("checkpoint / remat", fun_sourceinfo(fun),
325+
fun_signature(fun), args, kwargs, static_argnums, ())
323326
fun_, args = _remat_static_argnums(fun, static_argnums, args)
324327
args_flat, in_tree = tree_flatten((args, kwargs))
325328
in_avals = [shaped_abstractify(x) for x in args_flat]
326-
jaxpr, consts, out_tree = _trace_to_jaxpr(fun_, in_tree, tuple(in_avals))
329+
jaxpr, consts, out_tree = _trace_to_jaxpr(fun_, in_tree, tuple(in_avals), debug)
327330
out_flat = remat_p.bind(
328331
*consts, *args_flat, jaxpr=jaxpr, prevent_cse=prevent_cse,
329332
differentiated=False, policy=policy)
@@ -409,9 +412,8 @@ def new_fun(*dyn_args, **kwargs):
409412
# This helper is similar to those in control_flow/common.py, but with
410413
# remat-specific errors.
411414
@weakref_lru_cache
412-
def _trace_to_jaxpr(fun, in_tree, in_avals):
415+
def _trace_to_jaxpr(fun, in_tree, in_avals, debug):
413416
flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
414-
debug = pe.debug_info(fun, in_tree, out_tree, True, "checkpoint")
415417
try:
416418
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
417419
except core.ConcretizationTypeError as e:
@@ -445,10 +447,9 @@ def f_(*args):
445447
out_tree = lambda: tree_structure(out_shape)
446448
assert len(jaxpr.invars) == len(in_leaves)
447449
dbg = pe.debug_info(f, in_tree, out_tree, True, "saved_residuals")
448-
arg_info = pe.arg_info_all(dbg)
449-
return _saved_residuals(jaxpr, arg_info)
450+
return _saved_residuals(jaxpr, dbg.arg_names) # type: ignore
450451

451-
def _saved_residuals(jaxpr, arg_info) -> list[tuple[core.AbstractValue, str]]:
452+
def _saved_residuals(jaxpr, arg_names) -> list[tuple[core.AbstractValue, str]]:
452453
res_lits = [x for x in jaxpr.outvars if isinstance(x, core.Literal)]
453454
res_vars = {x for x in jaxpr.outvars if not isinstance(x, core.Literal)}
454455

@@ -463,9 +464,8 @@ def _saved_residuals(jaxpr, arg_info) -> list[tuple[core.AbstractValue, str]]:
463464

464465
for i, v in enumerate(jaxpr.invars):
465466
if v in res_vars:
466-
if arg_info is not None:
467-
arg_name, arg_path = arg_info[i]
468-
src = f'from the argument {arg_name}{keystr(arg_path)}'
467+
if arg_names is not None:
468+
src = f'from the argument {arg_names[i]}'
469469
else:
470470
src = 'from the argument at flattened index {i}'
471471
results.append((v.aval, src))

jax/_src/interpreters/partial_eval.py

Lines changed: 36 additions & 60 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,6 @@
1717
from collections.abc import Callable, Sequence, Hashable
1818
from contextlib import contextmanager
1919
from functools import partial
20-
import inspect
2120
import itertools as it
2221
import operator as op
2322
from typing import Any, NamedTuple, Union
@@ -46,7 +45,7 @@
4645
InputType, OutputType, get_referent, JaxprEqnContext)
4746
from jax._src.state.types import AbstractRef
4847
from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_unflatten,
49-
tree_flatten, tree_structure, KeyPath, generate_key_paths,
48+
tree_flatten, tree_structure, generate_key_paths,
5049
keystr)
5150
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
5251
merge_lists, partition_list, OrderedSet,
@@ -1529,8 +1528,7 @@ class DynamicJaxprTracer(core.Tracer):
15291528
def __init__(self, trace, aval, line_info=None):
15301529
self._trace = trace
15311530
self._line_info = line_info
1532-
# Needed for UnexpectedTracerError.
1533-
self._debug_info = self._trace.frame.debug_info
1531+
self._debug_info = self._trace.frame.debug_info # for UnexpectedTracerError
15341532
self.aval = aval
15351533

15361534
def full_lower(self):
@@ -1551,11 +1549,11 @@ def _origin_msg(self):
15511549

15521550
origin = ("The error occurred while tracing the function "
15531551
f"{dbg.func_src_info or '<unknown>'} for {dbg.traced_for}. ")
1554-
arg_info = arg_info_all(dbg)
1555-
# TODO(mattjj): figure out when not (invar_pos < len(arg_info))
1556-
if invar_pos and arg_info and all(i < len(arg_info) for i in invar_pos):
1557-
arg_info = [arg_info[i] for i in invar_pos]
1558-
arg_names = [f'{name}{keystr(path)}' for name, path in arg_info]
1552+
if invar_pos and dbg.arg_names:
1553+
try:
1554+
arg_names = [dbg.arg_names[i] for i in invar_pos]
1555+
except IndexError:
1556+
return "" # TODO(mattjj): figure out when not (invar_pos < len(arg_info))
15591557
if len(arg_names) == 1:
15601558
arg_info_str = f"the argument {arg_names[0]}"
15611559
elif len(arg_names) == 2:
@@ -1632,7 +1630,7 @@ class JaxprStackFrame:
16321630
attrs_tracked: list[tuple[Any, str]]
16331631
attrs_inits: list
16341632
attrs_vars: list[Var]
1635-
debug_info: DebugInfo | None
1633+
debug_info: lu.TracingDebugInfo | None
16361634

16371635
def __init__(self):
16381636
self.gensym = core.gensym()
@@ -2116,64 +2114,42 @@ def _jvp_jaxpr_zeros(f, store, in_zeros, zero_avals, *primal_tangent_avals):
21162114
store.store(out_zeros)
21172115
return [*out_primals, *out_nz_tangents]
21182116

2119-
# TODO(mattjj): remove this DebugInfo and helper functions, replace with
2120-
# api_util.py versions
2121-
2122-
class DebugInfo(NamedTuple):
2123-
func_src_info: str | None # f'{fun.__name__} at {filename}:{lineno}'
2124-
signature: inspect.Signature | None # inspect.signature(fun)
2125-
in_tree: PyTreeDef | None # caller/constructor might not have this info
2126-
out_tree: Callable[[], PyTreeDef] | None # lazy, not avail at trace time
2127-
has_kwargs: bool # whether in_tree corresponds to (args, kwargs) or args
2128-
traced_for: str # "jit", "scan", "make_jaxpr", etc
2129-
2130-
def debug_info(fn: Callable, in_tree: PyTreeDef | None,
2131-
out_tree_thunk: Callable[[], PyTreeDef] | None,
2132-
has_kwargs: bool, traced_for: str) -> DebugInfo:
2133-
sig = api_util.fun_signature(fn)
2117+
# Callers should be using linear_util.debug_info instead!
2118+
def debug_info(
2119+
fn: Callable,
2120+
in_tree: PyTreeDef | None,
2121+
out_tree_thunk: Callable[[], PyTreeDef] | None,
2122+
has_kwargs: bool,
2123+
traced_for: str
2124+
) -> lu.TracingDebugInfo | None:
21342125
src_info = fun_sourceinfo(fn)
2135-
return DebugInfo(src_info, sig, in_tree, out_tree_thunk, has_kwargs,
2136-
traced_for)
2137-
2138-
def debug_info_final(fn: lu.WrappedFun, traced_for: str) -> DebugInfo:
2139-
"Make a DebugInfo from data available to final-style primitives like pmap."
2140-
in_tree, out_tree, has_kws = flattened_fun_in_tree(fn) or (None, None, False)
2141-
return debug_info(fn.f, in_tree, out_tree, has_kws, traced_for)
2142-
2143-
def arg_info_all(dbg: DebugInfo) -> list[tuple[str, KeyPath]] | None:
2144-
ba = None if dbg.in_tree is None else sig_info(dbg)
2145-
if ba is None: return None
2146-
return [(name, key_path) for name, dummy_arg in ba.arguments.items()
2147-
for key_path, _ in generate_key_paths(dummy_arg)]
2148-
2149-
def sig_info(dbg: DebugInfo) -> inspect.BoundArguments | None:
2150-
if dbg.in_tree is None or dbg.signature is None: return None
2151-
try:
2152-
dummy_args = tree_unflatten(dbg.in_tree, [False] * dbg.in_tree.num_leaves)
2153-
except:
2154-
return None
2155-
args, kwargs = dummy_args if dbg.has_kwargs else (dummy_args, {})
2156-
try:
2157-
return dbg.signature.bind(*args, **kwargs)
2158-
except (TypeError, ValueError):
2159-
return None
2160-
2161-
def result_info(dbg: DebugInfo) -> list[KeyPath] | None:
2162-
if dbg.out_tree is None: return None
21632126
try:
2164-
num_leaves = dbg.out_tree().num_leaves
2165-
dummy_result = tree_unflatten(dbg.out_tree(), [False] * num_leaves)
2127+
dummy_args = tree_unflatten(in_tree, [False] * in_tree.num_leaves) # type: ignore
2128+
args, kwargs = dummy_args if has_kwargs else (dummy_args, {})
2129+
ba = api_util.fun_signature(fn).bind(*args, **kwargs) # type: ignore
2130+
arg_names = tuple(f'{name}{keystr(path)}' for name, dummy in ba.arguments.items()
2131+
for path, _ in generate_key_paths(dummy))
21662132
except:
2167-
return None
2168-
else:
2169-
return [path for path, _ in generate_key_paths(dummy_result)]
2133+
arg_names = None
2134+
def result_paths():
2135+
try:
2136+
out_tree = out_tree_thunk()
2137+
dummy_result = tree_unflatten(out_tree, [False] * out_tree.num_leaves)
2138+
except:
2139+
return None
2140+
return tuple(path for path, _ in generate_key_paths(dummy_result))
2141+
return lu.TracingDebugInfo(traced_for, src_info, arg_names, result_paths) # type: ignore
2142+
2143+
def debug_info_final(fn: lu.WrappedFun, traced_for: str) -> lu.TracingDebugInfo | None:
2144+
in_tree, out_tree, has_kws = flattened_fun_in_tree(fn) or (None, None, False)
2145+
return debug_info(fn.f, in_tree, out_tree, has_kws, traced_for)
21702146

21712147

21722148
@profiler.annotate_function
21732149
def trace_to_jaxpr_dynamic(
21742150
fun: lu.WrappedFun,
21752151
in_avals: Sequence[AbstractValue],
2176-
debug_info: DebugInfo | None = None,
2152+
debug_info: lu.TracingDebugInfo | None = None,
21772153
*,
21782154
keep_inputs: list[bool] | None = None,
21792155
) -> tuple[Jaxpr, list[AbstractValue], list[Any],
@@ -2197,7 +2173,7 @@ def trace_to_jaxpr_dynamic(
21972173

21982174
@profiler.annotate_function
21992175
def trace_to_jaxpr_dynamic2(
2200-
fun: lu.WrappedFun, debug_info: DebugInfo | None = None
2176+
fun: lu.WrappedFun, debug_info: lu.TracingDebugInfo | None = None
22012177
) -> tuple[Jaxpr, OutputType, list[Any]]:
22022178

22032179
trace = DynamicJaxprTrace()

jax/_src/linear_util.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -295,7 +295,6 @@ def valid_size(d) -> bool:
295295
class TracingDebugInfo(NamedTuple):
296296
# Packages up trace/staging-time debug info about a func and its parameters,
297297
# formed just before staging to a jaxpr and read in trace-time error messages.
298-
# TODO(mattjj): delete partial_eval.DebugInfo, replace all uses with this cls
299298
traced_for: str # e.g. 'jit', 'scan', etc
300299
func_src_info: str | None # e.g. f'{fun.__name__} at {filename}:{lineno}'
301300
arg_names: tuple[str, ...] # e.g. ('args[0]', ... )

jax/_src/pallas/core.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -429,7 +429,7 @@ def to_block_mapping(
429429
"pallas_call index_map",
430430
)
431431
index_map_src_info = NameAndSrcInfo.from_pallas_call(
432-
None, debug.func_src_info
432+
None, debug.func_src_info # type: ignore
433433
)
434434
with tracing_grid_env(grid, mapped_dims):
435435
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(

jax/interpreters/partial_eval.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,6 @@
2020
ConstFoldRule as ConstFoldRule,
2121
ConstVar as ConstVar,
2222
DCERule as DCERule,
23-
DebugInfo as DebugInfo,
2423
DynamicJaxprTrace as DynamicJaxprTrace,
2524
DynamicJaxprTracer as DynamicJaxprTracer,
2625
ForwardingRule as ForwardingRule,
@@ -40,7 +39,6 @@
4039
TracerId as TracerId,
4140
Val as Val,
4241
abstract_eval_fun as abstract_eval_fun,
43-
arg_info_all as arg_info_all,
4442
call_padding_rule as call_padding_rule,
4543
call_param_updaters as call_param_updaters,
4644
call_partial_eval_custom_rule as call_partial_eval_custom_rule,
@@ -79,8 +77,6 @@
7977
partial_eval_wrapper_nounits as partial_eval_wrapper_nounits,
8078
partition_pvals as partition_pvals,
8179
recipe_to_eqn as recipe_to_eqn,
82-
result_info as result_info,
83-
sig_info as sig_info,
8480
trace_to_jaxpr_dynamic as _trace_to_jaxpr_dynamic,
8581
trace_to_jaxpr_dynamic2 as trace_to_jaxpr_dynamic2,
8682
trace_to_jaxpr_nounits as trace_to_jaxpr_nounits,

tests/api_test.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6507,6 +6507,21 @@ def f(x):
65076507
else:
65086508
assert False
65096509

6510+
def test_concreteness_error_includes_user_code_with_static_argnums(self):
6511+
@partial(jax.remat, static_argnums=(1,))
6512+
def f(x, _):
6513+
if x > 0:
6514+
return x
6515+
else:
6516+
return jnp.sin(x)
6517+
6518+
try:
6519+
f(3., 1.)
6520+
except TracerBoolConversionError:
6521+
self.assertIn('x > 0', traceback.format_exc())
6522+
else:
6523+
assert False
6524+
65106525

65116526
@jtu.with_config(jax_pprint_use_color=False)
65126527
class JaxprTest(jtu.JaxTestCase):

0 commit comments

Comments
 (0)