Skip to content

Commit dbc3bcd

Browse files
committed
Apply forwarding in pjit linearization rule to avoid intermediate copies.
1 parent 5a3fc60 commit dbc3bcd

File tree

2 files changed

+22
-5
lines changed

2 files changed

+22
-5
lines changed

jax/_src/interpreters/ad.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -194,6 +194,11 @@ def new_arg(trace, primal_aval, nz):
194194
tangent_trace.invalidate()
195195
if attrs_tracked:
196196
raise NotImplementedError("TODO: attrs")
197+
tangent_jaxpr, used_consts, _ = pe.dce_jaxpr_consts(
198+
tangent_jaxpr, [True] * len(tangent_jaxpr.outvars),
199+
[False] * len(tangent_jaxpr.constvars) + [True] * len(tangent_jaxpr.invars))
200+
tangent_consts = [c for c, used in zip(tangent_consts, used_consts) if used]
201+
197202
residuals_and_primals = (*tangent_consts, *out_primals)
198203
residuals_and_primals = map(primal_trace.to_jaxpr_tracer, residuals_and_primals)
199204
primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals, debug_info)
@@ -871,6 +876,10 @@ def make_zero(aval):
871876
for (r, nz) in zip(out_tangents, out_nzs) if nz]
872877
in_tracers = [t for t, nz in zip(tangent_args, nonzeros) if nz]
873878
jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers, jvp.debug_info)
879+
jaxpr, used_consts, _ = pe.dce_jaxpr_consts(
880+
jaxpr, [True] * len(jaxpr.outvars),
881+
[False] * len(jaxpr.constvars) + [True] * len(jaxpr.invars))
882+
out_consts = [c for used, c in zip(used_consts, out_consts) if used]
874883

875884
def linearized(residuals, *tangents):
876885
nz_tangents_in = [t for (t, nz) in zip(tangents, nonzeros) if nz]

jax/_src/pjit.py

Lines changed: 13 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2076,14 +2076,22 @@ def _pjit_linearization(nzs, *primals_in, jaxpr,
20762076
donated_invars, ctx_mesh, name, keep_unused, inline,
20772077
compiler_options_kvs):
20782078
primal_jaxpr, num_residuals, nzs_out, tangent_jaxpr = ad.linearize_jaxpr(jaxpr, nzs)
2079-
# constvars will become residuals. Move them to the end of the ordinary args.
20802079
res_shardings = (UNSPECIFIED,) * num_residuals
20812080
res_layouts = (None,) * num_residuals
20822081
res_donated = (False,) * num_residuals
2082+
2083+
in_fwd = pe._jaxpr_forwarding(primal_jaxpr.jaxpr)
2084+
in_fwd, _ = split_list(in_fwd, [num_residuals])
2085+
keep = tuple(f is None for f in in_fwd) + (True,) * len(out_shardings)
2086+
primal_jaxpr = pe.prune_closed_jaxpr_outputs(primal_jaxpr, keep)
2087+
num_residuals = sum(f is None for f in in_fwd)
2088+
20832089
def tangent_fun(consts_, *tangents):
2090+
consts_it = iter(consts_)
2091+
res = [next(consts_it) if f is None else primals_in[f] for f in in_fwd]
2092+
assert next(consts_it, None) is None
20842093
tangents_nz = _filter_zeros(nzs, tangents)
2085-
assert len(consts_) == num_residuals
2086-
nz_tangents_out = pjit_p.bind(*(*tangents_nz, *consts_),
2094+
nz_tangents_out = pjit_p.bind(*(*tangents_nz, *res),
20872095
jaxpr=tangent_jaxpr,
20882096
in_shardings=_filter_zeros(nzs, in_shardings) + res_shardings,
20892097
out_shardings=_filter_zeros(nzs_out, out_shardings),
@@ -2106,9 +2114,9 @@ def _filter_zeros(is_nz_l, l):
21062114

21072115
ans = pjit_p.bind(*primals_in, jaxpr=primal_jaxpr,
21082116
in_shardings=in_shardings,
2109-
out_shardings=(*res_shardings, *out_shardings),
2117+
out_shardings=(*res_shardings[:num_residuals], *out_shardings),
21102118
in_layouts=in_layouts,
2111-
out_layouts=(*res_layouts, *out_layouts),
2119+
out_layouts=(*res_layouts[:num_residuals], *out_layouts),
21122120
donated_invars=donated_invars,
21132121
ctx_mesh=ctx_mesh,
21142122
name=name,

0 commit comments

Comments
 (0)