Skip to content

Commit 8abb1a7

Browse files
Merge pull request #25490 from jax-ml:more-linearize-fixes
PiperOrigin-RevId: 707584597
2 parents 72e5ca9 + 2be9a69 commit 8abb1a7

File tree

1 file changed

+33
-16
lines changed
  • jax/_src/interpreters

1 file changed

+33
-16
lines changed

jax/_src/interpreters/ad.py

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from jax._src import source_info_util
3131
from jax._src.ad_util import (
3232
add_jaxvals, replace_internal_symbolic_zeros,
33-
replace_rule_output_symbolic_zeros, Zero, zeros_like_aval)
33+
replace_rule_output_symbolic_zeros, Zero, zeros_like_aval, SymbolicZero)
3434
from jax._src.ad_util import zeros_like_p, add_jaxvals_p # noqa: F401
3535
from jax._src.api_util import flatten_fun, flatten_fun_nokwargs
3636
from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal)
@@ -157,6 +157,7 @@ def new_arg(primal_aval, nz):
157157
if attrs_tracked:
158158
raise NotImplementedError("TODO: attrs")
159159
residuals_and_primals = (*tangent_consts, *out_primals)
160+
residuals_and_primals = map(primal_trace.to_jaxpr_tracer, residuals_and_primals)
160161
primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals)
161162
num_residuals = len(tangent_consts)
162163
tangent_jaxpr = pe.close_jaxpr(convert_constvars_jaxpr_constvars_at_end(tangent_jaxpr))
@@ -168,8 +169,10 @@ def direct_linearize(traceable, primals, kwargs, *, has_aux=False, tag=None):
168169
with core.take_current_trace() as parent_trace:
169170
tangent_trace = pe.DynamicJaxprTrace()
170171
tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval()) for p in primals]
172+
tangents = [Zero.from_primal_value(t) if dtype(t) == float0 else t for t in tangents]
171173
linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=tag)
172174
tracers = [LinearizeTracer(linearize_trace, p, t) for p, t in zip(primals, tangents)]
175+
tracers = [t.full_lower() for t in tracers]
173176
with core.set_current_trace(linearize_trace):
174177
if has_aux:
175178
ans, aux = traceable.call_wrapped(*tracers)
@@ -586,20 +589,18 @@ def process_custom_jvp_call(self, prim, fun, f_jvp, tracers, *, symbolic_zeros):
586589
if all(type(t) is Zero for t in tangents_in):
587590
return prim.bind_with_trace(self.parent_trace, (fun, f_jvp, *primals_in),
588591
dict(symbolic_zeros=symbolic_zeros))
589-
with core.set_current_trace(self.parent_trace):
590-
if not symbolic_zeros:
591-
tangents_in = map(instantiate_zeros, tangents_in)
592-
else:
593-
tangents_in = map(replace_internal_symbolic_zeros, tangents_in)
594-
nonzeros_in = [type(t) is not Zero for t in tangents_in]
595592

596593
def _f_jvp(primals, tangents):
597594
outs = f_jvp.call_wrapped(*primals, *tangents)
598595
primals_out, tangents_out = split_list(outs, [len(outs) // 2])
599596
return primals_out, tangents_out
600597

601-
primals_out, tangent_nzs_out, residuals, linearized = linearize_from_jvp(
602-
_f_jvp, True, nonzeros_in, primals_in, {})
598+
with core.set_current_trace(self.parent_trace):
599+
instantiate_zeros = not symbolic_zeros
600+
nonzeros_in = [type(t) is not Zero for t in tangents_in]
601+
primals_out, tangent_nzs_out, residuals, linearized = linearize_from_jvp(
602+
_f_jvp, True, nonzeros_in, symbolic_zeros, instantiate_zeros, primals_in, {})
603+
603604
with core.set_current_trace(self.tangent_trace):
604605
tangents_out = linearized(residuals, *tangents_in)
605606
tangents_out = map(replace_rule_output_symbolic_zeros, tangents_out)
@@ -622,8 +623,8 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
622623
res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
623624
avals_out = [core.get_aval(x).to_tangent_aval() for x in primals_out]
624625

626+
tangents_in = map(instantiate_zeros, tangents_in)
625627
with core.set_current_trace(self.tangent_trace):
626-
tangents_in = map(instantiate_zeros, tangents_in)
627628
tangents_out = custom_lin_p.bind(
628629
*res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
629630
out_avals=avals_out, symbolic_zeros=symbolic_zeros)
@@ -666,14 +667,29 @@ def fallback_linearize_rule(_prim, _nonzeros, *primals, **params):
666667
if not jvp:
667668
msg = f"Differentiation rule for '{_prim}' not implemented"
668669
raise NotImplementedError(msg)
669-
return linearize_from_jvp(jvp, _prim.multiple_results, _nonzeros, primals, params)
670+
return linearize_from_jvp(jvp, _prim.multiple_results, _nonzeros, False, False, primals, params)
670671

671-
def linearize_from_jvp(jvp, multiple_results, nonzeros, primals, params):
672+
def linearize_from_jvp(jvp, multiple_results, nonzeros,
673+
user_facing_symbolic_zeros, instantiate_input_zeros, primals, params):
672674
current_name_stack = source_info_util.current_name_stack()
673675
with core.take_current_trace() as parent_trace:
674676
trace = pe.JaxprTrace(parent_trace, current_name_stack, core.TraceTag())
675677
tangent_avals = [get_aval(p).to_tangent_aval() for p in primals]
676-
tangent_args = [trace.new_arg(pe.PartialVal.unknown(aval)) if nz else Zero(aval)
678+
679+
def make_zero(aval):
680+
if instantiate_input_zeros:
681+
return zeros_like_aval(aval)
682+
elif user_facing_symbolic_zeros:
683+
return SymbolicZero(aval)
684+
else:
685+
return Zero(aval)
686+
687+
if user_facing_symbolic_zeros:
688+
zero_type = SymbolicZero
689+
else:
690+
zero_type = Zero
691+
692+
tangent_args = [trace.new_arg(pe.PartialVal.unknown(aval)) if nz else make_zero(aval)
677693
for aval, nz in zip(tangent_avals, nonzeros)]
678694
with core.set_current_trace(trace):
679695
out_primals, out_tangents = jvp(primals, tangent_args, **params)
@@ -683,10 +699,11 @@ def linearize_from_jvp(jvp, multiple_results, nonzeros, primals, params):
683699
out_tangents = [out_tangents]
684700

685701
out_primals = [trace.to_jaxpr_tracer(p).pval.get_known() for p in out_primals]
686-
out_nzs = [type(r) is not Zero for r in out_tangents]
702+
out_nzs = [type(r) is not zero_type for r in out_tangents]
687703
out_tangent_avals = [get_aval(p).to_tangent_aval() for p in out_primals]
688-
out_nz_tracers = [trace.to_jaxpr_tracer(r) for (r, nz) in zip(out_tangents, out_nzs) if nz]
689-
in_tracers = [t for t in tangent_args if type(t) is not Zero]
704+
out_nz_tracers = [trace.instantiate_const(trace.to_jaxpr_tracer(r))
705+
for (r, nz) in zip(out_tangents, out_nzs) if nz]
706+
in_tracers = [t for t, nz in zip(tangent_args, nonzeros) if nz]
690707
jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers)
691708

692709
def linearized(residuals, *tangents):

0 commit comments

Comments
 (0)