Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 22 additions & 22 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3221,29 +3221,29 @@ def block_until_ready(x):
A pytree with the same structure and values of the input, where the values
of all JAX array leaves are ready.
"""
def try_to_block(x):
try:
return x.block_until_ready()
except AttributeError:
return x

leaves = tree_util.tree_leaves(x)

# Collect all blockable leaves
arrays = []
for leaf in tree_leaves(x):
if isinstance(leaf, array.ArrayImpl):
arrays.append(leaf)
else:
try_to_block(leaf)

if not arrays:
# `arrays` will be empty if tree_leaves(x) is empty or all leaves are not
# jax.Array.
pass
elif len(arrays) == 1:
# Fast path for single array.
try_to_block(arrays[0])
else:
# Optimized for multiple arrays.
xc.batched_block_until_ready(arrays)
for leaf in leaves:
if hasattr(leaf, "block_until_ready"):
arrays.append(leaf)

# FIX: Only warn if the input is NOT empty but contains NO blockable objects.
# This allows mixed trees (arrays + ints) to pass silently, satisfying jakevdp,
# but warns on pure non-array inputs (ints), satisfying the original issue.
if leaves and not arrays:
import warnings
warnings.warn(
"jax.block_until_ready() was called on an input that contains no JAX arrays "
"or objects with a block_until_ready() method. This operation is a no-op.",
UserWarning,
stacklevel=2
)

# Delegate to the optimized batch blocker
if arrays:
xc.batched_block_until_ready(arrays)

return x

Expand Down