Skip to content

Commit 0d06731

Browse files
Merge pull request jax-ml#27797 from mattjj:shmap-fix
PiperOrigin-RevId: 744827753
2 parents 522add2 + 05ca023 commit 0d06731

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

jax/experimental/shard_map.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -958,7 +958,8 @@ def process_primitive(self, prim, tracers, params):
958958
rep_rule = _check_rules.get(prim, partial(_rule_missing, prim))
959959
out_rep = rep_rule(self.mesh, *in_rep, **params) if self.check else set()
960960
if prim.multiple_results:
961-
out_rep = [out_rep] * len(out_vals) if type(out_rep) is set else out_rep
961+
out_rep = (out_rep if isinstance(out_rep, (list, tuple))
962+
else [out_rep] * len(out_vals))
962963
return map(partial(ShardMapTracer, self), out_rep, out_vals)
963964
return ShardMapTracer(self, out_rep, out_vals)
964965

tests/shard_map_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -685,6 +685,19 @@ def f3():
685685
f3()
686686
jax.jit(f3)()
687687

688+
def test_multiple_result_primitive_with_none_sharding(self):
689+
# https://github.com/jax-ml/jax/issues/27673
690+
xs = jnp.arange(20).reshape(2, 10)
691+
mesh = jtu.create_mesh((2,), ("i",))
692+
y = shard_map(
693+
lambda x: jnp.split(x.squeeze(), 2),
694+
mesh=mesh,
695+
in_specs=(None,),
696+
out_specs=P("i"),
697+
)(xs)
698+
expected = jnp.repeat(xs, 2, axis=0).reshape(2, 2, 10)
699+
self.assertArraysEqual(y, expected)
700+
688701
def test_vmap_spmd_axis_name(self):
689702
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
690703

0 commit comments

Comments
 (0)