Skip to content

Commit 39d73a6

Browse files
Merge pull request jax-ml#25276 from mattjj:remove-vestigial-reducing-transposes
PiperOrigin-RevId: 703081286
2 parents 03861d4 + 6172a1f commit 39d73a6

File tree

5 files changed

+7
-9
lines changed

5 files changed

+7
-9
lines changed

jax/_src/ad_checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -652,7 +652,7 @@ def remat_transpose(out_cts, *in_primals, jaxpr, **params):
652652
for x in in_primals]
653653
assert next(in_cts_nz_, None) is next(in_zeros_, None) is None
654654
return in_cts
655-
ad.reducing_transposes[remat_p] = remat_transpose
655+
ad.primitive_transposes[remat_p] = remat_transpose
656656

657657
# TODO(mattjj): move this to ad.py
658658
def transpose_jaxpr(jaxpr: core.ClosedJaxpr, in_linear: bool | Sequence[bool],

jax/_src/interpreters/ad.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -277,9 +277,6 @@ def write_primal(v, val):
277277
call_jaxpr = params.pop('call_jaxpr')
278278
cts_out = get_primitive_transpose(eqn.primitive)(
279279
params, call_jaxpr, invals, cts_in, cts_in_avals)
280-
elif eqn.primitive in reducing_transposes:
281-
cts_out = reducing_transposes[eqn.primitive](
282-
cts_in, *invals, **eqn.params)
283280
else:
284281
cts_out = get_primitive_transpose(eqn.primitive)(
285282
cts_in, *invals, **eqn.params)
@@ -586,8 +583,6 @@ def to_concrete_value(self):
586583

587584
primitive_jvps : dict[core.Primitive, Callable] = {}
588585
primitive_transposes: dict[core.Primitive, Callable] = {}
589-
# transpose rules that internally perform reductions over the given named axes
590-
reducing_transposes: dict[core.Primitive, Callable] = {}
591586
primitive_linearizations : dict[core.Primitive, Callable] = {}
592587

593588
def deflinear(primitive, transpose_rule):
@@ -871,3 +866,6 @@ def __init__(self):
871866
"closed-over value into the custom_vjp function as an argument, and "
872867
"adapting the custom_vjp fwd and bwd rules.")
873868
super().__init__(msg)
869+
870+
# TODO(mattjj): remove this vestigial dict
871+
reducing_transposes: dict[core.Primitive, Callable] = {}

jax/_src/lax/control_flow/conditionals.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -780,7 +780,7 @@ def _cond_typecheck(bind_time, *in_atoms, branches):
780780
cond_p.def_impl(partial(dispatch.apply_primitive, cond_p))
781781
cond_p.def_effectful_abstract_eval(_cond_abstract_eval)
782782
ad.primitive_jvps[cond_p] = _cond_jvp
783-
ad.reducing_transposes[cond_p] = _cond_transpose
783+
ad.primitive_transposes[cond_p] = _cond_transpose
784784
pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval
785785
batching.fancy_primitive_batchers[cond_p] = _cond_batching_rule
786786
xla.register_initial_style_primitive(cond_p)

jax/_src/lax/control_flow/loops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1228,7 +1228,7 @@ def arrange_jaxpr_args_for_wrapped(args):
12281228
scan_p.def_impl(partial(dispatch.apply_primitive, scan_p))
12291229
scan_p.def_effectful_abstract_eval(_scan_abstract_eval)
12301230
ad.primitive_jvps[scan_p] = _scan_jvp
1231-
ad.reducing_transposes[scan_p] = _scan_transpose
1231+
ad.primitive_transposes[scan_p] = _scan_transpose
12321232
pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval
12331233
xla.register_initial_style_primitive(scan_p)
12341234
mlir.register_lowering(scan_p,

jax/_src/pjit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2385,7 +2385,7 @@ def prune_type(ty, xs, maybe_zeros):
23852385
_set_states(attrs_tracked, final_states)
23862386

23872387
return tree_unflatten(cts_out_treedef, nz_cts_out)
2388-
ad.reducing_transposes[pjit_p] = _pjit_transpose
2388+
ad.primitive_transposes[pjit_p] = _pjit_transpose
23892389

23902390

23912391
@weakref_lru_cache

0 commit comments

Comments
 (0)