Skip to content

Commit 05ca023

Browse files
mattjjjustinjfu
andcommitted
[shard-map] in eager shmap, handle all rep rule output cases
By convention, rep_rules can return three kinds of thing: 1. a sequence (tuple or list), 2. a single set, or 3. a single None. Even rules for primitives with multiple results can return single objects rather than sequences; the reason is that it's convenient not ot have to infer the number of outputs for higher-order primitives. In the latter two cases we rely on the caller (in this case, ShardMapTrace.process_primitive) to 'broadcast' the singleton result to a list of results equal to the number of outputs. Previously, the code was checking `if type(out_rep) is set`, which doesn't handle case 3. (We briefly tried another fix direction where we don't allow case 3, because we don't have case 3 in the upcoming VMA type system which replaces this stuff. But until that lands the easiest fix is just to handle all cases correctly.) fixes jax-ml#26148, fixes jax-ml#27673 Co-authored-by: Justin Fu <[email protected]>
1 parent e1e37f8 commit 05ca023

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

jax/experimental/shard_map.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -958,7 +958,8 @@ def process_primitive(self, prim, tracers, params):
958958
rep_rule = _check_rules.get(prim, partial(_rule_missing, prim))
959959
out_rep = rep_rule(self.mesh, *in_rep, **params) if self.check else set()
960960
if prim.multiple_results:
961-
out_rep = [out_rep] * len(out_vals) if type(out_rep) is set else out_rep
961+
out_rep = (out_rep if isinstance(out_rep, (list, tuple))
962+
else [out_rep] * len(out_vals))
962963
return map(partial(ShardMapTracer, self), out_rep, out_vals)
963964
return ShardMapTracer(self, out_rep, out_vals)
964965

tests/shard_map_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -685,6 +685,19 @@ def f3():
685685
f3()
686686
jax.jit(f3)()
687687

688+
def test_multiple_result_primitive_with_none_sharding(self):
689+
# https://github.com/jax-ml/jax/issues/27673
690+
xs = jnp.arange(20).reshape(2, 10)
691+
mesh = jtu.create_mesh((2,), ("i",))
692+
y = shard_map(
693+
lambda x: jnp.split(x.squeeze(), 2),
694+
mesh=mesh,
695+
in_specs=(None,),
696+
out_specs=P("i"),
697+
)(xs)
698+
expected = jnp.repeat(xs, 2, axis=0).reshape(2, 2, 10)
699+
self.assertArraysEqual(y, expected)
700+
688701
def test_vmap_spmd_axis_name(self):
689702
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
690703

0 commit comments

Comments
 (0)