Skip to content

Commit f45cbf3

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Fix a bug where full and use_mesh outside jit did not work because the shard passed to make_array_from_callback was sharded on all devices instead of just 1 device.
This is because `convert_element_type` returning an output on all devices of the mesh because of the surrounding `use_mesh` context. PiperOrigin-RevId: 735909962
1 parent 29bfd00 commit f45cbf3

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

jax/_src/lax/lax.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3016,6 +3016,7 @@ def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None, *,
30163016
isinstance(fill_value, array.ArrayImpl) and sharding._is_concrete):
30173017
broadcast_shape = sharding.shard_shape(shape)
30183018
shard = broadcast(fill_value, broadcast_shape)
3019+
shard = shard.addressable_data(0)
30193020
return array.make_array_from_callback(shape, sharding, lambda _: shard)
30203021

30213022
if sharding is not None and not sharding._is_concrete:

0 commit comments

Comments
 (0)