Skip to content

Commit 8fe8d24

Browse files
committed
Fixes to direct linearize
* Fix a bug in pjit linearization rule * Handle multiple results and zeros in fallback rule * Handle `has_aux` * Implement process_custom_vjp_call
1 parent 20236f1 commit 8fe8d24

File tree

1 file changed

+87
-25
lines changed
  • jax/_src/interpreters

1 file changed

+87
-25
lines changed

jax/_src/interpreters/ad.py

Lines changed: 87 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -143,21 +143,31 @@ def new_arg(primal_aval, nz):
143143

144144
def direct_linearize(traceable, *primals, **kwargs):
145145
has_aux = kwargs.pop('has_aux', False)
146-
assert not has_aux
147146
with core.take_current_trace() as parent_trace:
148147
tangent_trace = pe.DynamicJaxprTrace()
149148
tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval()) for p in primals]
150149
linearize_trace = LinearizeTrace(parent_trace, tangent_trace)
151150
tracers = [LinearizeTracer(linearize_trace, p, t) for p, t in zip(primals, tangents)]
152151
with core.set_current_trace(linearize_trace):
153-
ans = traceable.call_wrapped(*tracers)
154-
152+
if has_aux:
153+
ans, aux = traceable.call_wrapped(*tracers)
154+
aux_primals = [x.primal
155+
if isinstance(x, LinearizeTracer)
156+
and x._trace.tag is linearize_trace.tag
157+
else x for x in aux]
158+
else:
159+
ans = traceable.call_wrapped(*tracers)
160+
aux = None
155161
out_primals, out_tangents = unzip2(map(linearize_trace.to_primal_tangent_pair, ans))
162+
out_tangents = map(instantiate_zeros, out_tangents)
156163
out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents)
157164
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents)
158165
out_tangents_pvals = [pe.PartialVal.unknown(core.get_aval(t)) for t in out_tangents]
159166
del attrs_tracked # TODO: attrs
160-
return out_primals, out_tangents_pvals, jaxpr, consts
167+
if has_aux:
168+
return out_primals, out_tangents_pvals, jaxpr, consts, aux_primals
169+
else:
170+
return out_primals, out_tangents_pvals, jaxpr, consts
161171

162172
def linearize(traceable, *primals, **kwargs):
163173
if config.use_direct_linearize.value:
@@ -532,22 +542,45 @@ def to_primal_tangent_pair(self, val):
532542

533543
def process_primitive(self, primitive, args, params):
534544
primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, args))
535-
tangent_nonzeros = [type(t) is not Zero for t in tangents_in]
545+
tangent_nzs = [type(t) is not Zero for t in tangents_in]
536546
if all(type(t) is Zero for t in tangents_in):
537547
return primitive.bind_with_trace(self.parent_trace, primals_in, params)
538-
lin = primitive_linearizations.get(primitive)
539-
if lin is None:
540-
lin = partial(fallback_linearize_rule, primitive)
548+
fallback = partial(fallback_linearize_rule, primitive)
549+
lin = primitive_linearizations.get(primitive, fallback)
541550
with core.set_current_trace(self.parent_trace):
542-
primal_out, tangent_nonzeros_out, residuals, linearized = lin(
543-
tangent_nonzeros, *primals_in, **params)
551+
primal_out, tangent_nzs_out, residuals, linearized = lin(
552+
tangent_nzs, *primals_in, **params)
544553
with core.set_current_trace(self.tangent_trace):
545554
tangent_out = linearized(residuals, *tangents_in)
546555
if primitive.multiple_results:
547556
return [maybe_linearize_tracer(self, x, nz, t)
548-
for x, nz, t in zip(primal_out, tangent_nonzeros, tangent_out)]
557+
for x, nz, t in zip(primal_out, tangent_nzs_out, tangent_out)]
549558
else:
550-
return maybe_linearize_tracer(self, primal_out, tangent_nonzeros, tangent_out)
559+
return maybe_linearize_tracer(self, primal_out, tangent_nzs_out, tangent_out)
560+
561+
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
562+
symbolic_zeros):
563+
primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers))
564+
if all(type(t) is Zero for t in tangents_in):
565+
return prim.bind_with_trace(self.parent_trace,
566+
(fun, fwd, bwd, *primals_in),
567+
dict(out_trees=out_trees, symbolic_zeros=symbolic_zeros))
568+
fwd_in = [(p, type(t) is not Zero) for p, t in zip(primals_in, tangents_in)]
569+
fwd_in = [x for pair in fwd_in for x in pair] # flatten
570+
with core.set_current_trace(self.parent_trace):
571+
res_and_primals_out = fwd.call_wrapped(*fwd_in)
572+
573+
_, res_tree = out_trees()
574+
res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
575+
avals_out = [core.get_aval(x).to_tangent_aval() for x in primals_out]
576+
577+
with core.set_current_trace(self.tangent_trace):
578+
tangents_in = map(instantiate_zeros, tangents_in)
579+
tangents_out = custom_lin_p.bind(
580+
*res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
581+
out_avals=avals_out, symbolic_zeros=symbolic_zeros)
582+
tangent_nzs_out = [type(t) is not Zero for t in tangents_out]
583+
return map(partial(maybe_linearize_tracer, self), primals_out, tangent_nzs_out, tangents_out)
551584

552585
def maybe_linearize_tracer(trace, primal, is_nonzero, tangent):
553586
if is_nonzero:
@@ -557,21 +590,50 @@ def maybe_linearize_tracer(trace, primal, is_nonzero, tangent):
557590
assert type(tangent) is Zero
558591
return primal
559592

560-
def fallback_linearize_rule(prim, _, *args, **kwargs):
561-
assert not prim.multiple_results
562-
563-
def call_prim(*args_):
564-
return [prim.bind(*args_, **kwargs)]
565-
566-
with config.use_direct_linearize(False):
567-
(out_primal,), (out_tangent_pval,), jaxpr, consts, *_maybe_aux = linearize(
568-
lu.wrap_init(call_prim), *args, **kwargs)
593+
def fallback_linearize_rule(prim, nonzeros, *primals, **params):
594+
jvp = primitive_jvps.get(prim)
595+
if not jvp:
596+
msg = f"Differentiation rule for '{prim}' not implemented"
597+
raise NotImplementedError(msg)
598+
current_name_stack = source_info_util.current_name_stack()
599+
with core.take_current_trace() as parent_trace:
600+
trace = pe.JaxprTrace(parent_trace, current_name_stack, core.TraceTag())
601+
tangent_avals = [get_aval(p).to_tangent_aval() for p in primals]
602+
tangent_args = [trace.new_arg(pe.PartialVal.unknown(aval)) if nz else Zero(aval)
603+
for aval, nz in zip(tangent_avals, nonzeros)]
604+
with core.set_current_trace(trace):
605+
out_primals, out_tangents = jvp(primals, tangent_args, **params)
606+
607+
if not prim.multiple_results:
608+
out_primals = [out_primals]
609+
out_tangents = [out_tangents]
610+
611+
out_primals = [trace.to_jaxpr_tracer(p).pval.get_known() for p in out_primals]
612+
out_nzs = [type(r) is not Zero for r in out_tangents]
613+
out_tangent_avals = [get_aval(p).to_tangent_aval() for p in out_primals]
614+
out_nz_tracers = [trace.to_jaxpr_tracer(r) for (r, nz) in zip(out_tangents, out_nzs) if nz]
615+
in_tracers = [t for t in tangent_args if type(t) is not Zero]
616+
jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers)
617+
618+
def linearized(residuals, *tangents):
619+
nz_tangents_in = [t for (t, nz) in zip(tangents, nonzeros) if nz]
620+
nz_tangents_out = core.eval_jaxpr(jaxpr, residuals, *nz_tangents_in)
621+
nz_tangents_out_iter = iter(nz_tangents_out)
622+
all_out_tangents = [next(nz_tangents_out_iter) if nz else Zero(aval)
623+
for (aval, nz) in zip(out_tangent_avals, out_nzs)]
624+
if prim.multiple_results:
625+
return all_out_tangents
626+
else:
627+
out_tangent, = all_out_tangents
628+
return out_tangent
569629

570-
def linearized(residuals, *tangents):
571-
out_tangent, = core.eval_jaxpr(jaxpr, residuals, *tangents)
572-
return out_tangent
630+
if prim.multiple_results:
631+
return out_primals, out_nzs, out_consts, linearized
632+
else:
633+
out_primal, = out_primals
634+
out_nz, = out_nzs
635+
return out_primal, out_nz, out_consts, linearized
573636

574-
return out_primal, True, consts, linearized
575637

576638
class LinearizeTracer(Tracer):
577639
__slots__ = ['primal', 'tangent']

0 commit comments

Comments
 (0)