Skip to content

Commit 3d79df2

Browse files
Merge pull request #25048 from jax-ml:linearization-rule-signature
PiperOrigin-RevId: 699007033
2 parents 344d0d9 + 170718c commit 3d79df2

File tree

3 files changed

+24
-15
lines changed

3 files changed

+24
-15
lines changed

jax/_src/interpreters/ad.py

Lines changed: 18 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -483,39 +483,44 @@ def to_primal_tangent_pair(self, val):
483483

484484
def process_primitive(self, primitive, args, params):
  """Apply `primitive` under this trace, linearizing it if any tangent is nonzero.

  Each argument is split into a (primal, tangent) pair. If every tangent is a
  `Zero`, the primitive is simply bound on the parent trace. Otherwise the
  primitive's linearization rule is looked up in `primitive_linearizations`
  (falling back to `fallback_linearize_rule`) and called as
  `rule(tangent_nonzeros, *primals, **params)`; it must return
  `(primal_out, tangent_nonzeros_out, residuals, linearized)`. The tangent
  outputs are then computed with `linearized(residuals, *tangents_in)` on the
  tangent trace.

  Returns:
    The primal output(s), each wrapped by `maybe_linearize_tracer` according
    to the nonzero flag(s) reported by the rule for the *outputs*.
  """
  primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, args))
  tangent_nonzeros = [type(t) is not Zero for t in tangents_in]
  if not any(tangent_nonzeros):
    # Nothing to differentiate: evaluate the primitive on the parent trace.
    return primitive.bind_with_trace(self.parent_trace, primals_in, params)
  lin = primitive_linearizations.get(primitive)
  if lin is None:
    lin = partial(fallback_linearize_rule, primitive)
  with core.set_current_trace(self.parent_trace):
    primal_out, tangent_nonzeros_out, residuals, linearized = lin(
        tangent_nonzeros, *primals_in, **params)
  with core.set_current_trace(self.tangent_trace):
    tangent_out = linearized(residuals, *tangents_in)
  # Pair outputs with the rule's *output* nonzero flags. The input flags have
  # one entry per argument and need not line up with the outputs.
  if primitive.multiple_results:
    return [maybe_linearize_tracer(self, x, nz, t)
            for x, nz, t in zip(primal_out, tangent_nonzeros_out, tangent_out)]
  else:
    return maybe_linearize_tracer(
        self, primal_out, tangent_nonzeros_out, tangent_out)
499502

500-
def maybe_linearize_tracer(trace, primal, is_nonzero, tangent):
  """Box `primal` with its tangent only when the tangent is flagged nonzero.

  If `is_nonzero` is falsy, `tangent` is asserted to be of type `Zero` and the
  bare `primal` is returned. Otherwise `tangent` is asserted not to be of type
  `Zero` and a `LinearizeTracer(trace, primal, tangent)` is returned.
  """
  if not is_nonzero:
    assert type(tangent) is Zero
    return primal
  assert type(tangent) is not Zero
  return LinearizeTracer(trace, primal, tangent)
505510

506-
def fallback_linearize_rule(prim, _, *args, **kwargs):
  """Generic linearization rule used when `prim` has no registered rule.

  Linearizes `prim.bind` at `args` by calling `linearize`, with
  `config.use_direct_linearize` set to False for the duration of that call.
  The second positional parameter is ignored (callers pass the input tangent
  nonzero flags there).

  Returns:
    A 4-tuple `(out_primals, nonzeros, consts, linearized)`: the primal
    outputs, a list holding one `True` per primal output, the constants
    returned by `linearize` (to be handed back as residuals), and a function
    `linearized(residuals, *tangents)` that returns the list of output
    tangents.
  """
  def bind_with_params(*operands):
    return prim.bind(*operands, **kwargs)

  with config.use_direct_linearize(False):
    out_primals, out_tangents_pvals, jaxpr, consts, *_maybe_aux = linearize(
        lu.wrap_init(bind_with_params), *args, **kwargs)

  def linearized(residuals, *tangents):
    computed = iter(core.eval_jaxpr(jaxpr, residuals, *tangents))
    full_out = []
    for pval in out_tangents_pvals:
      # Known tangents come straight from the partial value; the rest are
      # taken, in order, from the evaluated jaxpr outputs.
      if pval.is_known():
        full_out.append(pval.get_known())
      else:
        full_out.append(next(computed))
    # Every output of the jaxpr must have been consumed.
    assert next(computed, None) is None
    return full_out

  nonzeros_out = [True for _ in out_primals]
  return out_primals, nonzeros_out, consts, linearized
519524

520525
class LinearizeTracer(Tracer):
521526
__slots__ = ['primal', 'tangent']
@@ -547,7 +552,7 @@ def to_concrete_value(self):
547552
primitive_transposes: dict[core.Primitive, Callable] = {}
# transpose rules that internally perform reductions over the given named axes
reducing_transposes: dict[core.Primitive, Callable] = {}
# linearization rules keyed by primitive; process_primitive uses
# fallback_linearize_rule for primitives that have no entry here
primitive_linearizations: dict[core.Primitive, Callable] = {}
551556

552557
def deflinear(primitive, transpose_rule):
553558
primitive_jvps[primitive] = partial(linear_jvp, primitive)

jax/_src/lax/lax.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2400,12 +2400,16 @@ def _sin_lowering(ctx, x):
24002400
return sine(ctx, x)
24012401
return _nary_lower_hlo(hlo.sine, ctx, x)
24022402

2403+
def _sin_p_lin(_, x):
  """Linearization rule for `sin_p`.

  The first parameter is ignored. Returns the 4-tuple
  `(sin_p.bind(x), True, cos(x), linearized)`, where `cos(x)` serves as the
  residual and `linearized(residual, t)` computes `mul(t, residual)`.
  """
  # TODO: allow this to happen in the linearized computation (need to fix backward_pass)
  cos_x = cos(x)

  def linearized(cos_x_, t):
    return mul(t, cos_x_)

  return sin_p.bind(x), True, cos_x, linearized
2406+
24032407
sin_p = standard_unop(_float | _complex, 'sin')
24042408
ad.defjvp(sin_p, lambda g, x: mul(g, cos(x)))
2409+
ad.primitive_linearizations[sin_p] = _sin_p_lin
24052410
mlir.register_lowering(sin_p, _sin_lowering)
24062411
batching.ragged_prop_rules[sin_p] = batching.ragged_mask_elementwise_rule
24072412

2408-
24092413
def _cos_complex(x):
24102414
# cos(x) = complex(cos(real(x)) * cosh(imag(x)), -sin(real(x)) * sinh(imag(x)))
24112415
# see also _sin_complex

tests/api_test.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4818,7 +4818,7 @@ def check_invariant_to_use_direct_linearize(f):
48184818
self.assertEqual(ans1, ans2)
48194819

48204820
def sin_of_sin(x):
4821-
return jnp.sin(jnp.sin(x))
4821+
return lax.sin(lax.sin(x))
48224822

48234823
check_invariant_to_use_direct_linearize(lambda: jax.grad(sin_of_sin)(1.0))
48244824

0 commit comments

Comments
 (0)