Skip to content

Commit b1d1dcf

Browse files
committed
Add linearization rule for pjit_p
1 parent 73fa0f4 commit b1d1dcf

File tree

5 files changed

+112
-27
lines changed

5 files changed

+112
-27
lines changed

jax/_src/interpreters/ad.py

Lines changed: 51 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -105,22 +105,56 @@ def jvp_subtrace_aux(f, store, tag, primals, tangents):
105105
store.store(aux_primals)
106106
return out_primals, out_tangents
107107

108+
def convert_constvars_jaxpr_constvars_at_end(jaxpr: core.Jaxpr) -> core.Jaxpr:
  """Return a jaxpr equivalent to ``jaxpr`` with its constvars moved into invars.

  The result has an empty ``constvars``; the former constvars are appended
  after the original ``invars``. ``outvars``, ``eqns`` and ``effects`` are
  passed through unchanged. When debug info is present, its ``arg_names`` is
  extended with one ``None`` per former constvar so it stays aligned with the
  longer ``invars``.
  """
  info = jaxpr.debug_info
  if info:
    # The appended constvars have no user-facing names.
    unnamed = (None,) * len(jaxpr.constvars)
    info = info._replace(arg_names=info.arg_names + unnamed)
  return core.Jaxpr(
      constvars=(),
      invars=jaxpr.invars + jaxpr.constvars,
      outvars=jaxpr.outvars,
      eqns=jaxpr.eqns,
      effects=jaxpr.effects,
      debug_info=info,
  )
115+
116+
def linearize_jaxpr(jaxpr, nonzeros):
  """Trace ``jaxpr`` under a ``LinearizeTrace``, yielding primal and tangent jaxprs.

  Args:
    jaxpr: closed jaxpr to linearize; its ``.jaxpr`` and ``.consts`` are
      evaluated with one ``LinearizeTracer`` per input variable.
    nonzeros: one flag per input variable. A truthy flag gives that input a
      fresh tangent argument on the tangent trace; a falsy flag gives it a
      symbolic ``Zero`` tangent.

  Returns:
    A 4-tuple ``(primal_jaxpr, num_residuals, nzs_out, tangent_jaxpr)``:
      * ``primal_jaxpr``: a ``core.ClosedJaxpr`` whose outputs are the
        residuals followed by the primal outputs.
      * ``num_residuals``: the number of leading residual outputs, i.e. the
        number of constants captured by the tangent trace.
      * ``nzs_out``: one bool per output, false where the output tangent is a
        ``Zero``.
      * ``tangent_jaxpr``: the tangent jaxpr with its constvars moved to the
        end of its inputs and then passed through ``pe.close_jaxpr``; it
        computes only the outputs whose tangent is not ``Zero``.
  """
  primal_trace = pe.DynamicJaxprTrace()
  tangent_trace = pe.DynamicJaxprTrace()
  lin_trace = LinearizeTrace(primal_trace, tangent_trace)

  # Pair each input's primal argument with either a traced tangent argument
  # or a symbolic zero, depending on its nonzero flag.
  tracers = []
  for var, is_nz in zip(jaxpr.jaxpr.invars, nonzeros):
    aval = var.aval
    primal = primal_trace.new_arg(aval)
    tangent_aval = aval.to_tangent_aval()
    if is_nz:
      tangent = tangent_trace.new_arg(tangent_aval)
    else:
      tangent = Zero(tangent_aval)
    tracers.append(LinearizeTracer(lin_trace, primal, tangent))

  with core.set_current_trace(lin_trace):
    outs = core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *tracers)

  out_primals, out_tangents = unzip2(map(lin_trace.to_primal_tangent_pair, outs))
  nzs_out = [type(t) is not Zero for t in out_tangents]

  # Only outputs with a non-Zero tangent become outputs of the tangent jaxpr.
  nz_tangent_tracers = []
  for is_nz, t in zip(nzs_out, out_tangents):
    if is_nz:
      nz_tangent_tracers.append(tangent_trace.to_jaxpr_tracer(t))

  tangent_jaxpr, tangent_consts, attrs_tracked = tangent_trace.to_jaxpr(nz_tangent_tracers)
  del attrs_tracked  # TODO: attrs

  # Constants captured by the tangent trace are the residuals; the primal
  # jaxpr emits them ahead of its ordinary outputs.
  num_residuals = len(tangent_consts)
  primal_outputs = (*tangent_consts, *out_primals)
  primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(primal_outputs)
  del attrs_tracked  # TODO: attrs

  tangent_jaxpr = pe.close_jaxpr(convert_constvars_jaxpr_constvars_at_end(tangent_jaxpr))
  return core.ClosedJaxpr(primal_jaxpr, primal_consts), num_residuals, nzs_out, tangent_jaxpr
143+
108144
def direct_linearize(traceable, *primals, **kwargs):
109145
has_aux = kwargs.pop('has_aux', False)
110146
assert not has_aux
111147
with core.take_current_trace() as parent_trace:
112-
frame = pe.JaxprStackFrame()
113-
tangent_trace = pe.DynamicJaxprTrace(frame)
148+
tangent_trace = pe.DynamicJaxprTrace()
114149
tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval()) for p in primals]
115-
tag = core.TraceTag()
116-
linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag)
150+
linearize_trace = LinearizeTrace(parent_trace, tangent_trace)
117151
tracers = [LinearizeTracer(linearize_trace, p, t) for p, t in zip(primals, tangents)]
118152
with core.set_current_trace(linearize_trace):
119153
ans = traceable.call_wrapped(*tracers)
120154

121155
out_primals, out_tangents = unzip2(map(linearize_trace.to_primal_tangent_pair, ans))
122156
out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents)
123-
jaxpr, consts, attrs_tracked = frame.to_jaxpr(tangent_trace, out_tangents)
157+
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents)
124158
out_tangents_pvals = [pe.PartialVal.unknown(core.get_aval(t)) for t in out_tangents]
125159
del attrs_tracked # TODO: attrs
126160
return out_primals, out_tangents_pvals, jaxpr, consts
@@ -469,8 +503,8 @@ def _primal_tangent_shapes_match(primal, tangent):
469503

470504
class LinearizeTrace(Trace):
471505

472-
def __init__(self, parent_trace, tangent_trace, tag):
473-
self.tag = tag
506+
def __init__(self, parent_trace, tangent_trace, tag=None):
507+
self.tag = core.TraceTag() if tag is None else tag
474508
self.parent_trace = parent_trace
475509
self.tangent_trace = tangent_trace
476510

@@ -509,18 +543,20 @@ def maybe_linearize_tracer(trace, primal, is_nonzero, tangent):
509543
return primal
510544

511545
def fallback_linearize_rule(prim, _, *args, **kwargs):
546+
assert not prim.multiple_results
547+
512548
def call_prim(*args_):
513-
return prim.bind(*args_, **kwargs)
549+
return [prim.bind(*args_, **kwargs)]
550+
514551
with config.use_direct_linearize(False):
515-
out_primals, out_tangents_pvals, jaxpr, consts, *_maybe_aux = linearize(
552+
(out_primal,), (out_tangent_pval,), jaxpr, consts, *_maybe_aux = linearize(
516553
lu.wrap_init(call_prim), *args, **kwargs)
554+
517555
def linearized(residuals, *tangents):
518-
tangents_out = iter(core.eval_jaxpr(jaxpr, residuals, *tangents))
519-
full_out = [pval.get_known() if pval.is_known() else next(tangents_out)
520-
for pval in out_tangents_pvals]
521-
assert next(tangents_out, None) is None
522-
return full_out
523-
return out_primals, [True for _ in out_primals], consts, linearized
556+
out_tangent, = core.eval_jaxpr(jaxpr, residuals, *tangents)
557+
return out_tangent
558+
559+
return out_primal, True, consts, linearized
524560

525561
class LinearizeTracer(Tracer):
526562
__slots__ = ['primal', 'tangent']

jax/_src/interpreters/partial_eval.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1575,6 +1575,7 @@ def get_referent(self):
15751575
val = frame.constvar_to_val.get(frame.tracer_to_var.get(id(self)))
15761576
return self if val is None else get_referent(val)
15771577

1578+
15781579
def _dynamic_jaxpr_tracer_shaped_abstractify(x):
15791580
return x.aval
15801581
api_util._shaped_abstractify_handlers[DynamicJaxprTracer] = _dynamic_jaxpr_tracer_shaped_abstractify
@@ -1805,8 +1806,8 @@ def vars_in_shape(aval: AbstractValue) -> Sequence[Var]:
18051806

18061807

18071808
class DynamicJaxprTrace(core.Trace):
1808-
def __init__(self, frame):
1809-
self.frame = frame
1809+
def __init__(self):
1810+
self.frame = JaxprStackFrame()
18101811

18111812
def invalidate(self):
18121813
# avoid cyclic refs
@@ -2068,6 +2069,9 @@ def transpose_jaxpr_thunk():
20682069
self.frame.add_eqn(eqn)
20692070
return out_tracers
20702071

2072+
def to_jaxpr(self, out_tracers: Sequence[Tracer]):
2073+
return self.frame.to_jaxpr(self, out_tracers)
2074+
20712075

20722076
custom_staging_rules: dict[Primitive, Callable] = {}
20732077

@@ -2166,19 +2170,17 @@ def trace_to_jaxpr_dynamic(
21662170
list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
21672171
keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs
21682172

2169-
frame = JaxprStackFrame()
2170-
frame.debug_info = debug_info
2171-
2172-
trace = DynamicJaxprTrace(frame)
2173+
trace = DynamicJaxprTrace()
2174+
trace.frame.debug_info = debug_info
21732175
with core.ensure_no_leaks(trace), source_info_util.reset_name_stack():
21742176
in_tracers = _input_type_to_tracers(trace.new_arg, in_avals)
21752177
in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
21762178
with core.set_current_trace(trace):
21772179
ans = fun.call_wrapped(*in_tracers)
21782180

21792181
out_tracers = map(trace.to_jaxpr_tracer, ans)
2180-
jaxpr, consts, attrs_tracked = frame.to_jaxpr(trace, out_tracers)
2181-
del trace, fun, frame, in_tracers, out_tracers, ans
2182+
jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers)
2183+
del trace, fun, in_tracers, out_tracers, ans
21822184

21832185
config.enable_checks.value and core.check_jaxpr(jaxpr)
21842186
return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked
@@ -2188,7 +2190,7 @@ def trace_to_jaxpr_dynamic2(
21882190
fun: lu.WrappedFun, debug_info: DebugInfo | None = None
21892191
) -> tuple[Jaxpr, OutputType, list[Any]]:
21902192

2191-
trace = DynamicJaxprTrace(JaxprStackFrame())
2193+
trace = DynamicJaxprTrace()
21922194
with core.ensure_no_leaks(trace), source_info_util.reset_name_stack():
21932195
trace.frame.debug_info = debug_info
21942196
in_avals, keep_inputs = unzip2(fun.in_type)

jax/_src/lax/lax.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2400,9 +2400,10 @@ def _sin_lowering(ctx, x):
24002400
return sine(ctx, x)
24012401
return _nary_lower_hlo(hlo.sine, ctx, x)
24022402

2403-
def _sin_p_lin(_, x):
2403+
def _sin_p_lin(nzs, x):
2404+
nz, = nzs
24042405
cos_x = cos(x) # TODO: allow this to happen in the linearized computation (need to fix backward_pass)
2405-
return (sin_p.bind(x), True, cos_x, lambda cos_x_, t: mul(t, cos_x_))
2406+
return (sin_p.bind(x), nz, cos_x, lambda cos_x_, t: mul(t, cos_x_))
24062407

24072408
sin_p = standard_unop(_float | _complex, 'sin')
24082409
ad.defjvp(sin_p, lambda g, x: mul(g, cos(x)))

jax/_src/pjit.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2107,6 +2107,52 @@ def _filter_zeros(is_nz_l, l):
21072107
ad.primitive_jvps[pjit_p] = _pjit_jvp
21082108

21092109

2110+
def _pjit_linearization(nzs, *primals_in, jaxpr,
                        in_shardings, out_shardings, in_layouts, out_layouts,
                        resource_env, donated_invars, name, keep_unused, inline,
                        compiler_options_kvs):
  """Linearization rule for ``pjit_p``.

  Linearizes ``jaxpr`` via ``ad.linearize_jaxpr`` and binds ``pjit_p`` on the
  resulting primal jaxpr, which returns the residuals followed by the primal
  outputs.

  Args:
    nzs: one flag per primal input; forwarded to ``ad.linearize_jaxpr`` and
      used to select which tangents, shardings, layouts and donation flags
      are passed to the tangent call.
    *primals_in: primal inputs to the ``pjit_p`` call.
    jaxpr and the remaining keyword arguments: the parameters of the
      ``pjit_p`` call being linearized.

  Returns:
    ``(primal_ans, nzs_out, residuals_ans, tangent_fun)``, where
    ``tangent_fun(residuals, *tangents)`` binds ``pjit_p`` on the tangent
    jaxpr with the selected tangents followed by the residuals.
  """

  def _select_nonzero(flags, values):
    # Keep the entries of ``values`` whose paired flag is truthy.
    return tuple(v for flag, v in zip(flags, values) if flag)

  primal_jaxpr, num_residuals, nzs_out, tangent_jaxpr = ad.linearize_jaxpr(jaxpr, nzs)

  # Residuals (the tangent jaxpr's former constvars) travel as extra values:
  # leading outputs of the primal call and trailing inputs of the tangent
  # call. They get UNSPECIFIED sharding, no layout, and are never donated.
  res_shardings = (UNSPECIFIED,) * num_residuals
  res_layouts = (None,) * num_residuals
  res_donated = (False,) * num_residuals

  def tangent_fun(consts_, *tangents):
    tangents_nz = _select_nonzero(nzs, tangents)
    assert len(consts_) == num_residuals
    tangent_params = dict(
        jaxpr=tangent_jaxpr,
        in_shardings=_select_nonzero(nzs, in_shardings) + res_shardings,
        out_shardings=_select_nonzero(nzs_out, out_shardings),
        in_layouts=_select_nonzero(nzs, in_layouts) + res_layouts,
        out_layouts=_select_nonzero(nzs_out, out_layouts),
        resource_env=resource_env,
        donated_invars=_select_nonzero(nzs, donated_invars) + res_donated,
        name=name,
        keep_unused=keep_unused,
        inline=inline,
        compiler_options_kvs=compiler_options_kvs,
    )
    return pjit_p.bind(*tangents_nz, *consts_, **tangent_params)

  primal_params = dict(
      jaxpr=primal_jaxpr,
      in_shardings=in_shardings,
      out_shardings=(*res_shardings, *out_shardings),
      in_layouts=in_layouts,
      out_layouts=(*res_layouts, *out_layouts),
      resource_env=resource_env,
      donated_invars=donated_invars,
      name=name,
      keep_unused=keep_unused,
      inline=inline,
      compiler_options_kvs=compiler_options_kvs,
  )
  outs = pjit_p.bind(*primals_in, **primal_params)
  residuals_ans, primal_ans = split_list(outs, [num_residuals])

  return primal_ans, nzs_out, residuals_ans, tangent_fun
2152+
2153+
ad.primitive_linearizations[pjit_p] = _pjit_linearization
2154+
2155+
21102156
def _pjit_partial_eval(trace, *in_tracers,
21112157
jaxpr, in_shardings, out_shardings,
21122158
in_layouts, out_layouts, resource_env, donated_invars,

tests/api_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4818,7 +4818,7 @@ def check_invariant_to_use_direct_linearize(f):
48184818
self.assertEqual(ans1, ans2)
48194819

48204820
def sin_of_sin(x):
4821-
return lax.sin(lax.sin(x))
4821+
return lax.sin(jax.jit(lax.sin)(x))
48224822

48234823
check_invariant_to_use_direct_linearize(lambda: jax.grad(sin_of_sin)(1.0))
48244824

0 commit comments

Comments
 (0)