Skip to content

Commit 44a13c9

Browse files
yashk2810jax authors
authored andcommitted
Merge code between make_jaxpr and jit(f).trace.
The semantics of `make_jaxpr` are preserved here i.e. `make_jaxpr` still closes over tracers but `jit(f).trace` doesn't. Since we can keep the existing behavior and still merge the implementation is a good cleanup! Fixes #21116 PiperOrigin-RevId: 641347140
1 parent 25cc84b commit 44a13c9

File tree

5 files changed

+48
-44
lines changed

5 files changed

+48
-44
lines changed

jax/_src/api.py

Lines changed: 22 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@
7171
from jax._src.layout import Layout, AutoLayout
7272
from jax._src.traceback_util import api_boundary
7373
from jax._src import tree_util
74-
from jax._src.util import unzip2, safe_map, safe_zip, wrap_name, wraps
74+
from jax._src.util import (unzip2, safe_map, safe_zip, wrap_name, wraps,
75+
split_list)
7576
from jax._src import util
7677

7778
from jax._src.interpreters import ad
@@ -2358,43 +2359,34 @@ def make_jaxpr(fun: Callable,
23582359
g:f32[] = mul f c
23592360
in (g,) }
23602361
"""
2361-
check_callable(fun)
2362-
static_argnums = _ensure_index_tuple(static_argnums)
2363-
2364-
def abstractify(args, kwargs):
2365-
flat_args, in_tree = tree_flatten((args, kwargs))
2366-
if abstracted_axes is None:
2367-
return map(shaped_abstractify, flat_args), in_tree, [True] * len(flat_args)
2368-
else:
2369-
axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
2370-
in_type = pe.infer_lambda_input_type(axes_specs, flat_args)
2371-
in_avals, keep_inputs = unzip2(in_type)
2372-
return in_avals, in_tree, keep_inputs
2362+
try:
2363+
hash(fun)
2364+
weakref.ref(fun)
2365+
except TypeError:
2366+
fun = partial(fun)
23732367

23742368
@wraps(fun)
23752369
@api_boundary
23762370
def make_jaxpr_f(*args, **kwargs):
2377-
f = lu.wrap_init(fun)
2378-
if static_argnums:
2379-
dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
2380-
f, args = argnums_partial(f, dyn_argnums, args)
2381-
in_avals, in_tree, keep_inputs = abstractify(args, kwargs)
2382-
in_type = tuple(zip(in_avals, keep_inputs))
2383-
f, out_tree = flatten_fun(f, in_tree)
2384-
f = lu.annotate(f, in_type)
2385-
debug_info = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')
23862371
with ExitStack() as stack:
23872372
for axis_name, size in axis_env or []:
23882373
stack.enter_context(core.extend_axis_env(axis_name, size, None))
2389-
jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(
2390-
f, debug_info=debug_info)
2391-
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
2374+
traced = jit(fun, static_argnums=static_argnums,
2375+
abstracted_axes=abstracted_axes).trace(*args, **kwargs)
2376+
# `jit` converts tracers in consts to args but that breaks the semantics of
2377+
# `make_jaxpr`. Hence convert the tracers in args back to consts in jaxpr.
2378+
if traced._num_consts:
2379+
consts, _ = split_list(traced._args_flat, [traced._num_consts])
2380+
jaxpr_ = pe.convert_invars_to_constvars(traced.jaxpr.jaxpr,
2381+
traced._num_consts)
2382+
jaxpr = core.ClosedJaxpr(jaxpr_, consts)
2383+
else:
2384+
jaxpr = traced.jaxpr
23922385
if return_shape:
2393-
out_avals, _ = unzip2(out_type)
2394-
out_shapes_flat = [
2395-
ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals]
2396-
return closed_jaxpr, tree_unflatten(out_tree(), out_shapes_flat)
2397-
return closed_jaxpr
2386+
out = [ShapeDtypeStruct(o.shape, o.dtype, getattr(o, 'named_shape', None))
2387+
for o in jaxpr.out_avals]
2388+
return jaxpr, tree_unflatten(tree_structure(traced.out_info), out)
2389+
return jaxpr
23982390

23992391
make_jaxpr_f.__module__ = "jax"
24002392
if hasattr(fun, "__qualname__"):

jax/_src/pjit.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ class PjitInfo(NamedTuple):
171171

172172

173173
def _python_pjit_helper(jit_info, *args, **kwargs):
174-
(args_flat, params, _, out_tree, _, arg_names,
174+
(args_flat, params, in_avals, _, out_tree, _, arg_names, _,
175175
attrs_tracked) = _infer_params(jit_info, args, kwargs)
176176

177177
for arg in args_flat:
@@ -197,7 +197,7 @@ def _python_pjit_helper(jit_info, *args, **kwargs):
197197
if params['jaxpr'].consts:
198198
raise TypeError(e.args[0]) from e
199199
else:
200-
for arg, name, aval in zip(args_flat, arg_names, params['jaxpr'].in_avals):
200+
for arg, name, aval in zip(args_flat, arg_names, in_avals):
201201
try:
202202
xla.canonicalize_dtype(arg)
203203
except xla.InvalidInputException as _:
@@ -491,7 +491,7 @@ def lower(*args, **kwargs):
491491

492492
@api_boundary
493493
def eval_shape(*args, **kwargs):
494-
_, params, _, out_tree, _, _, _ = _infer_params(jit_info, args, kwargs)
494+
_, params, _, _, out_tree, _, _, _, _ = _infer_params(jit_info, args, kwargs)
495495
out_s = [None if is_unspecified(s) else s for s in params['out_shardings']]
496496
# TODO(yashkatariya): Add `Layout` to SDS.
497497
out = [api.ShapeDtypeStruct(x.shape, x.dtype, x.named_shape, sharding=s)
@@ -503,16 +503,15 @@ def trace(*args, **kwargs) -> stages.Traced:
503503
lowering_parameters = kwargs.pop(
504504
'_experimental_lowering_parameters', mlir.LoweringParameters())
505505

506-
(args_flat, params, in_tree, out_tree, donated_invars,
507-
arg_names, _) = _infer_params(jit_info, args, kwargs)
506+
(args_flat, params, in_avals, in_tree, out_tree, donated_invars,
507+
arg_names, num_consts, _) = _infer_params(jit_info, args, kwargs)
508508

509509
donate_argnums = tuple(i for i, d in enumerate(donated_invars) if d)
510-
jaxpr = params['jaxpr']
511-
args_info = stages.make_args_info(in_tree, jaxpr.in_avals, donate_argnums)
510+
args_info = stages.make_args_info(in_tree, in_avals, donate_argnums)
512511
lower_callable = partial(_resolve_and_lower, args_flat, **params,
513512
lowering_parameters=lowering_parameters)
514-
return stages.Traced(jaxpr, args_info, params["name"], out_tree,
515-
lower_callable, args_flat, arg_names)
513+
return stages.Traced(params['jaxpr'], args_info, params["name"], out_tree,
514+
lower_callable, args_flat, arg_names, num_consts)
516515

517516
wrapped = _cpp_pjit(jit_info)
518517
wrapped.lower = lower
@@ -662,8 +661,9 @@ def _infer_params(jit_info, args, kwargs):
662661
keep_unused=keep_unused,
663662
inline=inline,
664663
)
665-
return (consts + args_flat, params, in_tree, out_tree(),
666-
donated_invars, dbg.arg_names if dbg else None, attrs_tracked)
664+
return (consts + args_flat, params, in_avals, in_tree, out_tree(),
665+
donated_invars, dbg.arg_names if dbg else None, len(consts),
666+
attrs_tracked)
667667

668668
def _extract_implicit_args(
669669
in_type: Sequence[tuple[core.AbstractValue, bool]],

jax/_src/stages.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -427,17 +427,19 @@ class CompiledCallParams(NamedTuple):
427427

428428
class Traced(Stage):
429429
__slots__ = ["jaxpr", "args_info", "fun_name", "_out_tree", "_lower_callable",
430-
"_args_flat", "_arg_names"]
430+
"_args_flat", "_arg_names", "_num_consts"]
431431

432432
def __init__(self, jaxpr: core.ClosedJaxpr, args_info, fun_name, out_tree,
433-
lower_callable, args_flat=None, arg_names=None):
433+
lower_callable, args_flat=None, arg_names=None,
434+
num_consts: int = 0):
434435
self.jaxpr = jaxpr
435436
self.args_info = args_info
436437
self.fun_name = fun_name
437438
self._out_tree = out_tree
438439
self._lower_callable = lower_callable
439440
self._args_flat = args_flat
440441
self._arg_names = arg_names
442+
self._num_consts = num_consts
441443

442444
@property
443445
def out_info(self):

tests/api_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4058,6 +4058,16 @@ def f():
40584058
return jnp.exp(dtype(0))
40594059
f() # doesn't error
40604060

4061+
def test_vmap_make_jaxpr_close_over_tracer(self):
4062+
def run(inp):
4063+
def f(x, y):
4064+
return x + y
4065+
g = lambda x: f(x, inp)
4066+
jaxpr = jax.make_jaxpr(g)(1)
4067+
return jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 1)
4068+
4069+
jax.vmap(run)(jnp.arange(2)) # doesn't crash
4070+
40614071
def test_large_python_ints(self):
40624072
with self.assertRaises(OverflowError):
40634073
jnp.multiply(2 ** 100, 3.)

tests/lax_control_flow_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2518,7 +2518,7 @@ def f(c, a):
25182518
scan_fun = lambda c, xs: lax.scan(f, c, xs)
25192519

25202520
def new_jaxpr():
2521-
jaxpr = jax.make_jaxpr(scan_fun)(c, xs).jaxpr
2521+
jaxpr = jax.make_jaxpr(partial(scan_fun))(c, xs).jaxpr
25222522
scan = next(eqn for eqn in jaxpr.eqns if eqn.primitive.name == 'scan')
25232523
return jaxpr, scan
25242524

0 commit comments

Comments
 (0)