@@ -1145,6 +1145,14 @@ def _pbroadcast_batcher(vals_in, dims_in, *, axes, axis_index_groups):
11451145register_standard_check = \
11461146 lambda prim : _check_rules .setdefault (prim , partial (_standard_check , prim ))
11471147
1148+ def _eq_rep (mesh , r1 , r2 ) -> bool :
1149+ if r1 != r2 and r1 is None or r2 is None :
1150+ r1 , r2 = _remove_none_rep (mesh , r1 ), _remove_none_rep (mesh , r2 )
1151+ return r1 == r2
1152+
1153+ def _remove_none_rep (mesh , r ):
1154+ return set (mesh .axis_names ) if r is None else r
1155+
11481156def _no_rewrite (prim , rule , mesh , in_rep , * args , ** params ):
11491157 out_vals = prim .bind (* args ,** params )
11501158 out_rep = rule (mesh , * in_rep , ** params )
@@ -1371,7 +1379,7 @@ def _scan_check(mesh, *in_rep, jaxpr, num_consts, num_carry, **_):
13711379 _ , carry_rep_in , _ = split_list (in_rep , [num_consts , num_carry ])
13721380 out_rep = _check_rep (mesh , jaxpr .jaxpr , in_rep )
13731381 carry_rep_out , _ = split_list (out_rep , [num_carry ])
1374- if carry_rep_in != carry_rep_out :
1382+ if not all ( map ( partial ( _eq_rep , mesh ), carry_rep_in , carry_rep_out )) :
13751383 raise Exception ("Scan carry input and output got mismatched replication "
13761384 f"types { carry_rep_in } and { carry_rep_out } . Please open an "
13771385 "issue at https://github.com/jax-ml/jax/issues, and as a "
@@ -1447,7 +1455,7 @@ def _cond_rule(mesh, *in_rep, branches):
14471455 out_rep = _check_rep (mesh , branches [0 ].jaxpr , args_rep )
14481456 for branch in branches [1 :]:
14491457 out_rep_ = _check_rep (mesh , branch .jaxpr , args_rep )
1450- if out_rep_ != out_rep :
1458+ if not all ( map ( partial ( _eq_rep , mesh ), out_rep , out_rep_ )) :
14511459 raise Exception ("The branches of cond produced mismatched replication "
14521460 "types. Please open an issue at "
14531461 "https://github.com/jax-ml/jax/issues, and as a "
@@ -2189,7 +2197,7 @@ def _efficient_transpose_rewrite(fun, mesh, in_names, out_names_thunk):
21892197def _efficient_transpose_rewrite_nomatch (f , store , mesh , in_reps , * args ):
21902198 with core .take_current_trace () as parent :
21912199 tag = core .TraceTag ()
2192- t = RewriteTrace (parent_trace = parent , tag = tag , mesh = mesh )
2200+ t = RewriteTrace (parent_trace = parent , tag = tag , mesh = mesh )
21932201 in_tracers = map (partial (RewriteTracer , t ), in_reps , args )
21942202 with core .set_current_trace (t ):
21952203 ans = f (* in_tracers )
0 commit comments