Skip to content

Commit 97b1faa

Browse files
committed
Fixes the random key sharding in shard_map.
1 parent: 879fa12 · commit: 97b1faa

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

jax/experimental/shard_map.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -724,7 +724,6 @@ def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
724724
aval_in, aval_out, x):
725725
if prod([size for n, size in mesh.shape.items() if n not in auto]) == 1:
726726
return x
727-
manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh)
728727
axes = {name: i for i, ns in names.items() for name in ns}
729728
ns = _make_scoped_manual_sharding(ctx, mesh, axes)
730729
if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
@@ -734,6 +733,7 @@ def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
734733
unspecified = set(range(aval_in.ndim)) if auto else set()
735734
sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, shard_proto,
736735
unspecified_dims=unspecified)
736+
manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh)
737737
return mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto, unspecified)
738738

739739
def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
@@ -746,6 +746,8 @@ def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
746746
ns = sharding_impls.physical_sharding(aval_out, ns)
747747
aval_out = core.physical_aval(aval_out)
748748
unspecified = set(range(aval_out.ndim)) if auto else set()
749+
if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
750+
aval_in = core.physical_aval(aval_in)
749751
manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh)
750752
sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, manual_proto, unspecified_dims=unspecified)
751753
shard_proto = ns._to_xla_hlo_sharding(aval_out.ndim).to_proto()

tests/shard_map_test.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2207,6 +2207,23 @@ def f(x):
22072207
#
22082208
# f(x) # don't crash
22092209

2210+
def test_partial_auto_of_random_keys(self):
2211+
if config.use_shardy_partitioner.value:
2212+
self.skipTest('Shardy does not support full-to-shard.')
2213+
2214+
mesh = jtu.create_mesh((4, 2), ('i', 'j'))
2215+
keys = jax.random.split(jax.random.key(0), 8)
2216+
2217+
@jax.jit
2218+
def f(x):
2219+
return shard_map(lambda k: k,
2220+
mesh, in_specs=P('i'), out_specs=P('i'),
2221+
check_rep=False, auto=frozenset({'j'}))(keys)
2222+
2223+
y = f(keys) # don't crash
2224+
self.assertAllClose(jax.random.key_data(y), jax.random.key_data(keys),
2225+
check_dtypes=False)
2226+
22102227
def test_vmap_grad_shmap_spmd_axis_name_residuals(self):
22112228
# https://github.com/jax-ml/jax/pull/21032
22122229
mesh = jtu.create_mesh((4, 2), ('i', 'j'))

0 commit comments

Comments
 (0)