Commit e210507

danielsuo authored and copybara-github committed
[pmap] Handle edge cases in get_from_first_device for jax_pmap_shmap_merge.
pmap replication/sharding semantics are not consistent, so different users of the acme library do different things. The real solution is to migrate away from pmap entirely, but IIUC, there may not be staffing for this.

Add a robust _unreplicate helper function that handles:
- 0-dimensional scalar arrays (cannot be indexed)
- SingleDeviceSharding or single local device (use simple indexing)
- Fully replicated arrays (return addressable shard data without squeeze)
- Default sharded case (squeeze the leading replicated dimension)

This prevents errors when processing nested structures containing arrays with different sharding configurations under jax_pmap_shmap_merge mode.

PiperOrigin-RevId: 864435972
Change-Id: I6e151cbe0ed771f62762ec4f2d8e11a972e2a4ac
1 parent 05ced10 commit e210507

1 file changed, +15 -3 lines changed

acme/jax/utils.py

Lines changed: 15 additions & 3 deletions
@@ -387,9 +387,21 @@ def get_from_first_device(nest: N, as_numpy: bool = True) -> N:
   # Avoid degraded performance under the new jax.pmap. See
   # https://docs.jax.dev/en/latest/migrate_pmap.html#int-indexing-into-sharded-arrays.
   if jax.config.jax_pmap_shmap_merge:
-    zeroth_nest = jax.tree_util.tree_map(
-        lambda x: x.addressable_shards[0].data.squeeze(0), nest
-    )
+
+    def _unreplicate(x):
+      if hasattr(x, 'ndim') and x.ndim == 0:
+        return x
+      if not hasattr(x, 'sharding') or isinstance(
+          x.sharding, jax.sharding.SingleDeviceSharding
+      ):
+        return x
+      if len(jax.local_devices()) == 1:
+        return x[0]
+      if x.sharding.is_fully_replicated:
+        return x.addressable_shards[0].data
+      return x.addressable_shards[0].data.squeeze(0)
+
+    zeroth_nest = jax.tree_util.tree_map(_unreplicate, nest)
   else:
     zeroth_nest = jax.tree_util.tree_map(lambda x: x[0], nest)
   return jax.device_get(zeroth_nest) if as_numpy else zeroth_nest
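
For reviewers who want to poke at the branch order outside of acme, here is a minimal sketch that lifts the helper into a standalone function. The name `unreplicate` and the example `nest` are hypothetical and exist only for illustration. It assumes a recent JAX where freshly created arrays carry a `SingleDeviceSharding`, and it only exercises the early-return branches, since the replicated and sharded branches need a multi-device setup.

```python
import jax
import jax.numpy as jnp


def unreplicate(x):
  """Standalone copy of the `_unreplicate` branches from the diff (illustrative)."""
  # 0-d arrays have no leading device axis to index or squeeze.
  if hasattr(x, 'ndim') and x.ndim == 0:
    return x
  # Non-array leaves and single-device arrays are returned unchanged.
  if not hasattr(x, 'sharding') or isinstance(
      x.sharding, jax.sharding.SingleDeviceSharding
  ):
    return x
  # With one local device, plain integer indexing drops the leading axis.
  if len(jax.local_devices()) == 1:
    return x[0]
  # Fully replicated: the shard data is returned without a squeeze.
  if x.sharding.is_fully_replicated:
    return x.addressable_shards[0].data
  # Default sharded case: squeeze the leading replicated dimension.
  return x.addressable_shards[0].data.squeeze(0)


# Hypothetical nest mixing leaf kinds; none of these leaves is pmap-sharded.
nest = {
    'step': jnp.asarray(3),      # 0-d array
    'params': jnp.ones((2, 3)),  # ordinary single-device array
    'lr': 0.5,                   # plain Python float, no `sharding` attribute
}
out = jax.tree_util.tree_map(unreplicate, nest)
```

Note that in this sketch, as in the diff, a leaf with `SingleDeviceSharding` is passed through as-is rather than indexed, so `out['params']` keeps its `(2, 3)` shape.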
