Skip to content

Commit d69fe1f

Browse files
Merge pull request jax-ml#24873 from jax-ml:no-gen-linear-util
PiperOrigin-RevId: 696315182
2 parents 842d93e + 1c9b23c commit d69fe1f

File tree

17 files changed

+311
-287
lines changed

17 files changed

+311
-287
lines changed

jax/_src/api_util.py

Lines changed: 33 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,13 @@ def _ensure_str_tuple(x: str | Iterable[str]) -> tuple[str, ...]:
6868
else:
6969
return tuple(map(_ensure_str, x))
7070

71-
@lu.transformation_with_aux
72-
def flatten_fun(in_tree, *args_flat):
71+
@lu.transformation_with_aux2
72+
def flatten_fun(f, store, in_tree, *args_flat):
7373
py_args, py_kwargs = tree_unflatten(in_tree, args_flat)
74-
ans = yield py_args, py_kwargs
75-
yield tree_flatten(ans)
74+
ans = f(*py_args, **py_kwargs)
75+
ans, out_tree = tree_flatten(ans)
76+
store.store(out_tree)
77+
return ans
7678

7779
def apply_flat_fun(fun, io_tree, *py_args):
7880
in_tree_expected, out_tree = io_tree
@@ -82,11 +84,13 @@ def apply_flat_fun(fun, io_tree, *py_args):
8284
ans = fun(*args)
8385
return tree_unflatten(out_tree, ans)
8486

85-
@lu.transformation_with_aux
86-
def flatten_fun_nokwargs(in_tree, *args_flat):
87+
@lu.transformation_with_aux2
88+
def flatten_fun_nokwargs(f, store, in_tree, *args_flat):
8789
py_args = tree_unflatten(in_tree, args_flat)
88-
ans = yield py_args, {}
89-
yield tree_flatten(ans)
90+
ans = f(*py_args)
91+
ans, out_tree = tree_flatten(ans)
92+
store.store(out_tree)
93+
return ans
9094

9195
def apply_flat_fun_nokwargs(fun, io_tree, py_args):
9296
in_tree_expected, out_tree = io_tree
@@ -118,17 +122,18 @@ def flattened_fun_in_tree(
118122
else:
119123
return in_tree, lambda: out_tree_store.val, has_kwargs
120124

121-
@lu.transformation_with_aux
122-
def flatten_fun_nokwargs2(in_tree, *args_flat):
125+
@lu.transformation_with_aux2
126+
def flatten_fun_nokwargs2(f, store, in_tree, *args_flat):
123127
py_args = tree_unflatten(in_tree, args_flat)
124-
pair = yield py_args, {}
128+
pair = f(*py_args)
125129
if not isinstance(pair, (list, tuple)) or len(pair) != 2:
126130
raise TypeError("expected function with aux output to return a two-element "
127131
f"tuple, but got type {type(pair)} with value {pair!r}")
128132
ans, aux = pair
129133
ans_flat, ans_tree = tree_flatten(ans)
130134
aux_flat, aux_tree = tree_flatten(aux)
131-
yield (ans_flat, aux_flat), (ans_tree, aux_tree)
135+
store.store((ans_tree, aux_tree))
136+
return ans_flat, aux_flat
132137

133138
class _HashableWithStrictTypeEquality:
134139
"""Box object used when comparing static arguments as a jit key.
@@ -277,18 +282,16 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: tuple[int, ...],
277282

278283
return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args
279284

280-
@lu.transformation
281-
def _argnums_partial(dyn_argnums, fixed_args, *dyn_args, **kwargs):
285+
@lu.transformation2
286+
def _argnums_partial(f, dyn_argnums, fixed_args, *dyn_args, **kwargs):
282287
sentinel = object()
283288
args = [sentinel] * (len(fixed_args) + len(dyn_args))
284289
for i, arg in zip(dyn_argnums, dyn_args):
285290
args[i] = arg
286291
fixed_args_ = iter(fixed_args)
287292
args = [next(fixed_args_).val if x is sentinel else x for x in args]
288293
assert next(fixed_args_, sentinel) is sentinel
289-
ans = yield args, kwargs
290-
yield ans
291-
294+
return f(*args, **kwargs)
292295

293296
def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...],
294297
kwargs: dict[str, Any]):
@@ -311,11 +314,10 @@ def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...],
311314

312315
return _argnames_partial(f, WrapKwArgs(fixed_kwargs)), dyn_kwargs
313316

314-
@lu.transformation
315-
def _argnames_partial(fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs):
317+
@lu.transformation2
318+
def _argnames_partial(f, fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs):
316319
kwargs = dict({k: v.val for k, v in fixed_kwargs.val.items()}, **dyn_kwargs)
317-
ans = yield args, kwargs
318-
yield ans
320+
return f(*args, **kwargs)
319321

320322

321323
@lru_cache(maxsize=4096)
@@ -435,9 +437,9 @@ def flat_out_axes(
435437
f, out_axes = _flat_out_axes(f, tuple(leaves), treedef)
436438
return f, HashableFunction(out_axes, closure=(tuple(leaves), treedef))
437439

438-
@lu.transformation_with_aux
439-
def _flat_out_axes(leaves, treedef, *args, **kwargs):
440-
ans = yield args, kwargs
440+
@lu.transformation_with_aux2
441+
def _flat_out_axes(f, store, leaves, treedef, *args, **kwargs):
442+
ans = f(*args, **kwargs)
441443
spec = tree_unflatten(treedef, leaves)
442444
try:
443445
spec_flat = tuple(broadcast_prefix(spec, ans, is_leaf=lambda x: x is None))
@@ -449,7 +451,8 @@ def _flat_out_axes(leaves, treedef, *args, **kwargs):
449451
"that the `out_axes` argument to `pmap` is a pytree prefix of the "
450452
"pmapped function's output.")
451453
raise ValueError(msg) from None
452-
yield ans, spec_flat
454+
store.store(spec_flat)
455+
return ans
453456

454457
def check_callable(fun):
455458
# In Python 3.10+, the only thing stopping us from supporting staticmethods
@@ -683,11 +686,12 @@ def _arg_names(fn_signature, args, kwargs, static_argnums, static_argnames,
683686
return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items()
684687
for path, l in generate_key_paths(x) if l is not static)
685688

686-
@lu.transformation_with_aux
687-
def result_paths(*args, **kwargs):
689+
@lu.transformation_with_aux2
690+
def result_paths(f, store, *args, **kwargs):
688691
"linear_util transform to get output pytree paths of pre-flattened function."
689-
ans = yield args, kwargs
690-
yield ans, [keystr(path) for path, _ in generate_key_paths(ans)]
692+
ans = f(*args, **kwargs)
693+
store.store([keystr(path) for path, _ in generate_key_paths(ans)])
694+
return ans
691695

692696
def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: TracingDebugInfo | None,
693697
result_paths: tuple[str, ...] | None = None,

jax/_src/checkify.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -330,11 +330,12 @@ def update_error(error, pred, code, metadata, payload, effect_type):
330330

331331
## Checkify transformation for plumbing functional error values.
332332

333-
@lu.transformation_with_aux
334-
def _flatten_and_get_error_metadata_thunk(*invals):
335-
error, out = yield invals, {}
333+
@lu.transformation_with_aux2
334+
def _flatten_and_get_error_metadata_thunk(f, store, *invals):
335+
error, out = f(*invals)
336336
out_vals, out_tree = jtu.tree_flatten((error, out))
337-
yield out_vals, (out_tree, set(error._pred.keys()))
337+
store.store((out_tree, set(error._pred.keys())))
338+
return out_vals
338339

339340
def default_checkify_rule(primitive: core.Primitive, error: Error,
340341
enabled_errors, *invals: core.Value,
@@ -438,10 +439,12 @@ def checkify_jaxpr_flat_hashable(jaxpr, hashable_consts, enabled_errors,
438439
consts = tuple(c.x for c in hashable_consts)
439440
return checkify_jaxpr_flat(jaxpr, consts, enabled_errors, err_tree, *args)
440441

441-
@lu.transformation_with_aux
442-
def flatten_fun_output(*args):
443-
ans = yield args, {}
444-
yield tree_flatten(ans)
442+
@lu.transformation_with_aux2
443+
def flatten_fun_output(f, store, *args):
444+
ans = f(*args)
445+
ans, out_tree = tree_flatten(ans)
446+
store.store(out_tree)
447+
return ans
445448

446449

447450
def _reduce_any_error(error: Error):

jax/_src/custom_derivatives.py

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,14 @@ def _zeros_like_pytree(x):
7575

7676

7777
# like the api_util.py function, but also grabs output avals for error checking
78-
@lu.transformation_with_aux
79-
def _flatten_fun_nokwargs(in_tree, *args_flat):
78+
@lu.transformation_with_aux2
79+
def _flatten_fun_nokwargs(f, store, in_tree, *args_flat):
8080
py_args = tree_unflatten(in_tree, args_flat)
81-
ans = yield py_args, {}
81+
ans = f(*py_args)
8282
ans_flat, ans_tree = tree_flatten(ans)
8383
ans_avals = [core.get_aval(x) for x in ans_flat]
84-
yield ans_flat, (ans_tree, ans_avals)
84+
store.store((ans_tree, ans_avals))
85+
return ans_flat
8586

8687

8788
### JVPs
@@ -266,18 +267,18 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable
266267
def _add_args(f, extra_args):
267268
return _add_args_(f, tuple(Unhashable(arg) for arg in extra_args))
268269

269-
@lu.transformation
270-
def _add_args_(extra_args, *args, **kwargs):
270+
@lu.transformation2
271+
def _add_args_(f, extra_args, *args, **kwargs):
271272
extra_args = tuple(arg.val for arg in extra_args)
272273
all_args = (extra_args + args)
273-
yield (yield all_args, kwargs)
274+
return f(*all_args, **kwargs)
274275

275-
@partial(lu.transformation_with_aux, use_eq_store=True)
276-
def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args):
276+
@partial(lu.transformation_with_aux2, use_eq_store=True)
277+
def _flatten_jvp(f, store, primal_name, jvp_name, in_tree, maybe_out_type, *args):
277278
primals_in, tangents_in = split_list(args, [len(args) // 2])
278279
py_primals = tree_unflatten(in_tree, primals_in)
279280
py_tangents = tree_unflatten(in_tree, tangents_in)
280-
pair_out = yield (py_primals, py_tangents), {}
281+
pair_out = f(py_primals, py_tangents)
281282
if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2:
282283
msg = (f"Custom JVP rule {jvp_name} for function {primal_name} "
283284
"must produce a pair (list or tuple of length two) representing "
@@ -348,7 +349,8 @@ def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args):
348349
if av_et != av_t)
349350

350351
raise TypeError(msg.format('\n'.join(disagreements)))
351-
yield primals_out + tangents_out, (out_tree, primal_avals)
352+
store.store((out_tree, primal_avals))
353+
return primals_out + tangents_out
352354

353355
class CustomJVPCallPrimitive(core.Primitive):
354356
multiple_results = True
@@ -652,15 +654,15 @@ def _check_for_tracers(x):
652654
"arguments should typically not be indicated as nondiff_argnums.")
653655
raise UnexpectedTracerError(msg)
654656

655-
@partial(lu.transformation_with_aux, use_eq_store=True)
656-
def _flatten_fwd(symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type,
657+
@partial(lu.transformation_with_aux2, use_eq_store=True)
658+
def _flatten_fwd(f, store, symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type,
657659
*args):
658660
if symbolic_zeros:
659661
args = [CustomVJPPrimal(x, z) for x, z in zip(args[::2], args[1::2])]
660662
else:
661663
args = args[::2]
662664
py_args = tree_unflatten(in_tree, args)
663-
pair_out = yield py_args, {}
665+
pair_out = f(*py_args)
664666
if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2:
665667
msg = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} "
666668
"must produce a pair (list or tuple of length two) where the first "
@@ -710,16 +712,17 @@ def _flatten_fwd(symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type,
710712
"shapes/dtypes of:\n"
711713
f""" {str(ty_tree_).replace("'", "")}""")
712714
raise TypeError(m)
713-
yield (*res, *primals_out), (out_tree, res_tree)
715+
store.store((out_tree, res_tree))
716+
return (*res, *primals_out)
714717

715-
@lu.transformation
716-
def _flatten_bwd(in_tree, in_avals, out_trees, *args):
718+
@lu.transformation2
719+
def _flatten_bwd(f, in_tree, in_avals, out_trees, *args):
717720
out_tree, res_tree = out_trees()
718721
assert len(args) == res_tree.num_leaves + out_tree.num_leaves
719722
res, cts_out = split_list(args, [res_tree.num_leaves])
720723
py_res = tree_unflatten(res_tree, res)
721724
py_cts_out = tree_unflatten(out_tree, cts_out)
722-
py_cts_in = yield (py_res, py_cts_out), {}
725+
py_cts_in = f(py_res, py_cts_out)
723726
if isinstance(py_cts_in, list) and len(py_cts_in) == len(treedef_children(in_tree)):
724727
py_cts_in = tuple(py_cts_in)
725728
# For each None in py_cts_in, indicating an argument for which the rule
@@ -775,7 +778,7 @@ def append(x, d):
775778
f"to an input of shape/dtype {a.str_short()}.")
776779
raise ValueError(msg)
777780
results.append(ct)
778-
yield results
781+
return results
779782

780783
# TODO(mattjj): remove both these exceptions to cotangent compatibility check
781784
def _temporary_dtype_exception(a, a_) -> bool:
@@ -1425,11 +1428,11 @@ def fun_jaxpr_thunk():
14251428

14261429
return wrapped_fwd
14271430

1428-
@lu.transformation
1429-
def _fix_fwd_args(*args):
1431+
@lu.transformation2
1432+
def _fix_fwd_args(f, *args):
14301433
args = [(x, True) for x in args]
14311434
args = [x for pair in args for x in pair]
1432-
yield (yield args, {})
1435+
return f(*args)
14331436

14341437
def _remat_opt_impl(
14351438
*args,

0 commit comments

Comments
 (0)