Skip to content

Commit 5486a63

Browse files
danielsuoFlax Authors
authored andcommitted
No public description
PiperOrigin-RevId: 877435039
1 parent c860bee commit 5486a63

File tree

3 files changed

+22
-27
lines changed

3 files changed

+22
-27
lines changed

examples/wmt/train.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -363,10 +363,9 @@ def pre_pmap(xs):
363363
def post_pmap(xs):
364364
# Avoid degraded performance under the new jax.pmap. See
365365
# https://docs.jax.dev/en/latest/migrate_pmap.html#int-indexing-into-sharded-arrays.
366-
if jax.config.jax_pmap_shmap_merge:
367-
return jax.tree_util.tree_map(
368-
lambda x: x.addressable_shards[0].data.squeeze(0), xs)
369-
return jax.tree_util.tree_map(lambda x: x[0], xs)
366+
return jax.tree_util.tree_map(
367+
lambda x: x.addressable_shards[0].data.squeeze(0), xs
368+
)
370369

371370
return post_pmap(host_psum(pre_pmap(in_tree)))
372371

flax/jax_utils.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -47,22 +47,22 @@ def replicate(tree, devices=None):
4747

4848
def unreplicate(tree):
4949
"""Returns a single instance of a replicated array."""
50-
if jax.config.jax_pmap_shmap_merge:
51-
def _unreplicate_one(x):
52-
# Avoid degraded performance under the new jax.pmap. See
53-
# https://docs.jax.dev/en/latest/migrate_pmap.html#int-indexing-into-sharded-arrays.
54-
# Handle 0-dimensional (scalar) arrays - cannot index into them
55-
if hasattr(x, 'ndim') and x.ndim == 0:
56-
return x
57-
if (not hasattr(x, 'sharding') or
58-
isinstance(x.sharding, jax.sharding.SingleDeviceSharding) or
59-
len(jax.local_devices()) == 1):
60-
return x[0]
61-
if x.sharding.is_fully_replicated:
62-
return x.addressable_shards[0].data
63-
return x.addressable_shards[0].data.squeeze(0)
64-
return jax.tree_util.tree_map(_unreplicate_one, tree)
65-
return jax.tree_util.tree_map(lambda x: x[0], tree)
50+
def _unreplicate_one(x):
51+
# Avoid degraded performance under the new jax.pmap.
52+
# Handle 0-dimensional (scalar) arrays - cannot index into them
53+
if hasattr(x, 'ndim') and x.ndim == 0:
54+
return x
55+
if (
56+
not hasattr(x, 'sharding')
57+
or isinstance(x.sharding, jax.sharding.SingleDeviceSharding)
58+
or len(jax.local_devices()) == 1
59+
):
60+
return x[0]
61+
if x.sharding.is_fully_replicated:
62+
return x.addressable_shards[0].data
63+
return x.addressable_shards[0].data.squeeze(0)
64+
65+
return jax.tree_util.tree_map(_unreplicate_one, tree)
6666

6767

6868
def pmean(xs, axis_name):

flax/training/common_utils.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,9 @@ def get_metrics(device_metrics):
8080
"""
8181
# We select the first element of x in order to get a single copy of a
8282
# device-replicated metric.
83-
# Avoid degraded performance under the new jax.pmap. See
84-
# https://docs.jax.dev/en/latest/migrate_pmap.html#int-indexing-into-sharded-arrays.
85-
if jax.config.jax_pmap_shmap_merge:
86-
device_metrics = jax.tree_util.tree_map(
87-
lambda x: x.addressable_shards[0].data.squeeze(0), device_metrics)
88-
else:
89-
device_metrics = jax.tree_util.tree_map(lambda x: x[0], device_metrics)
83+
device_metrics = jax.tree_util.tree_map(
84+
lambda x: x.addressable_shards[0].data.squeeze(0), device_metrics
85+
)
9086
metrics_np = jax.device_get(device_metrics)
9187
return stack_forest(metrics_np)
9288

0 commit comments

Comments
 (0)