diff --git a/jax/_src/api.py b/jax/_src/api.py index f817dc0ee6c0..97dac389db5f 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -3221,29 +3221,29 @@ def block_until_ready(x): A pytree with the same structure and values of the input, where the values of all JAX array leaves are ready. """ - def try_to_block(x): - try: - return x.block_until_ready() - except AttributeError: - return x - + leaves = tree_util.tree_leaves(x) + + # Collect all blockable leaves arrays = [] - for leaf in tree_leaves(x): - if isinstance(leaf, array.ArrayImpl): - arrays.append(leaf) - else: - try_to_block(leaf) - - if not arrays: - # `arrays` will be empty if tree_leaves(x) is empty or all leaves are not - # jax.Array. - pass - elif len(arrays) == 1: - # Fast path for single array. - try_to_block(arrays[0]) - else: - # Optimized for multiple arrays. - xc.batched_block_until_ready(arrays) + for leaf in leaves: + if hasattr(leaf, "block_until_ready"): + arrays.append(leaf) + + # FIX: Only warn if the input is NOT empty but contains NO blockable objects. + # This allows mixed trees (arrays + ints) to pass silently, satisfying jakevdp, + # but warns on pure non-array inputs (ints), satisfying the original issue. + if leaves and not arrays: + import warnings + warnings.warn( + "jax.block_until_ready() was called on an input that contains no JAX arrays " + "or objects with a block_until_ready() method. This operation is a no-op.", + UserWarning, + stacklevel=2 + ) + + # Delegate to the optimized batch blocker + if arrays: + xc.batched_block_until_ready(arrays) return x