@@ -277,9 +277,6 @@ def write_primal(v, val):
277277 call_jaxpr = params .pop ('call_jaxpr' )
278278 cts_out = get_primitive_transpose (eqn .primitive )(
279279 params , call_jaxpr , invals , cts_in , cts_in_avals )
280- elif eqn .primitive in reducing_transposes :
281- cts_out = reducing_transposes [eqn .primitive ](
282- cts_in , * invals , ** eqn .params )
283280 else :
284281 cts_out = get_primitive_transpose (eqn .primitive )(
285282 cts_in , * invals , ** eqn .params )
@@ -586,8 +583,6 @@ def to_concrete_value(self):
586583
587584primitive_jvps : dict [core .Primitive , Callable ] = {}
588585primitive_transposes : dict [core .Primitive , Callable ] = {}
589- # transpose rules that internally perform reductions over the given named axes
590- reducing_transposes : dict [core .Primitive , Callable ] = {}
591586primitive_linearizations : dict [core .Primitive , Callable ] = {}
592587
593588def deflinear (primitive , transpose_rule ):
@@ -871,3 +866,6 @@ def __init__(self):
871866 "closed-over value into the custom_vjp function as an argument, and "
872867 "adapting the custom_vjp fwd and bwd rules." )
873868 super ().__init__ (msg )
869+
870+ # TODO(mattjj): remove this vestigial dict
871+ reducing_transposes : dict [core .Primitive , Callable ] = {}
0 commit comments