Skip to content

Commit ef68063

Browse files
Merge pull request jax-ml#27809 from mattjj:26621
PiperOrigin-RevId: 745212009
2 parents 29cb6cd + ae95797 commit ef68063

File tree

2 files changed

+35
-3
lines changed

2 files changed

+35
-3
lines changed

jax/experimental/shard_map.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,6 +1145,14 @@ def _pbroadcast_batcher(vals_in, dims_in, *, axes, axis_index_groups):
11451145
register_standard_check = \
11461146
lambda prim: _check_rules.setdefault(prim, partial(_standard_check, prim))
11471147

1148+
def _eq_rep(mesh, r1, r2) -> bool:
1149+
if r1 != r2 and r1 is None or r2 is None:
1150+
r1, r2 = _remove_none_rep(mesh, r1), _remove_none_rep(mesh, r2)
1151+
return r1 == r2
1152+
1153+
def _remove_none_rep(mesh, r):
1154+
return set(mesh.axis_names) if r is None else r
1155+
11481156
def _no_rewrite(prim, rule, mesh, in_rep, *args, **params):
11491157
out_vals = prim.bind(*args,**params)
11501158
out_rep = rule(mesh, *in_rep, **params)
@@ -1371,7 +1379,7 @@ def _scan_check(mesh, *in_rep, jaxpr, num_consts, num_carry, **_):
13711379
_, carry_rep_in, _ = split_list(in_rep, [num_consts, num_carry])
13721380
out_rep = _check_rep(mesh, jaxpr.jaxpr, in_rep)
13731381
carry_rep_out, _ = split_list(out_rep, [num_carry])
1374-
if carry_rep_in != carry_rep_out:
1382+
if not all(map(partial(_eq_rep, mesh), carry_rep_in, carry_rep_out)):
13751383
raise Exception("Scan carry input and output got mismatched replication "
13761384
f"types {carry_rep_in} and {carry_rep_out}. Please open an "
13771385
"issue at https://github.com/jax-ml/jax/issues, and as a "
@@ -1447,7 +1455,7 @@ def _cond_rule(mesh, *in_rep, branches):
14471455
out_rep = _check_rep(mesh, branches[0].jaxpr, args_rep)
14481456
for branch in branches[1:]:
14491457
out_rep_ = _check_rep(mesh, branch.jaxpr, args_rep)
1450-
if out_rep_ != out_rep:
1458+
if not all(map(partial(_eq_rep, mesh), out_rep, out_rep_)):
14511459
raise Exception("The branches of cond produced mismatched replication "
14521460
"types. Please open an issue at "
14531461
"https://github.com/jax-ml/jax/issues, and as a "
@@ -2189,7 +2197,7 @@ def _efficient_transpose_rewrite(fun, mesh, in_names, out_names_thunk):
21892197
def _efficient_transpose_rewrite_nomatch(f, store, mesh, in_reps, *args):
21902198
with core.take_current_trace() as parent:
21912199
tag = core.TraceTag()
2192-
t = RewriteTrace(parent_trace = parent, tag = tag, mesh=mesh)
2200+
t = RewriteTrace(parent_trace=parent, tag=tag, mesh=mesh)
21932201
in_tracers = map(partial(RewriteTracer, t), in_reps, args)
21942202
with core.set_current_trace(t):
21952203
ans = f(*in_tracers)

tests/shard_map_test.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2798,6 +2798,30 @@ def f(x):
27982798

27992799
f(x) # doesn't crash
28002800

2801+
def test_rep_none_canonicalization(self):
2802+
# https://github.com/jax-ml/jax/issues/26621
2803+
if config.use_shardy_partitioner.value:
2804+
self.skipTest('complex values fail under shardy')
2805+
N = 8
2806+
xs = jnp.ones((8, N), dtype=jnp.int32)
2807+
variables = jax.random.normal(jax.random.key(1), (N, N), jnp.complex64)
2808+
mesh = jtu.create_mesh((2,), ('i',))
2809+
in_specs = (P(), P("i"),)
2810+
out_specs = P("i")
2811+
2812+
variables = jax.lax.with_sharding_constraint(variables, NamedSharding(mesh, P()))
2813+
xs = jax.lax.with_sharding_constraint(xs, NamedSharding(mesh, P('i')))
2814+
2815+
def fun(v, xs):
2816+
# Commenting this single line below makes everything work
2817+
v = jax.scipy.linalg.expm(v)
2818+
v = v.sum()
2819+
return v * xs.sum(axis=-1).astype(v.dtype)
2820+
2821+
res = fun(variables, xs)
2822+
fun_shard_map = shard_map(fun, mesh=mesh, in_specs=in_specs, out_specs=out_specs)
2823+
res = fun_shard_map(variables, xs) # don't crash
2824+
28012825

28022826
class FunSpec(NamedTuple):
28032827
name: str

0 commit comments

Comments
 (0)