@@ -3221,29 +3221,29 @@ def block_until_ready(x):
32213221 A pytree with the same structure and values of the input, where the values
32223222 of all JAX array leaves are ready.
32233223 """
3224- def try_to_block (x ):
3225- try :
3226- return x .block_until_ready ()
3227- except AttributeError :
3228- return x
3229-
3224+ leaves = tree_util .tree_leaves (x )
3225+
3226+ # Collect all blockable leaves
32303227 arrays = []
3231- for leaf in tree_leaves (x ):
3232- if isinstance (leaf , array .ArrayImpl ):
3233- arrays .append (leaf )
3234- else :
3235- try_to_block (leaf )
3236-
3237- if not arrays :
3238- # `arrays` will be empty if tree_leaves(x) is empty or all leaves are not
3239- # jax.Array.
3240- pass
3241- elif len (arrays ) == 1 :
3242- # Fast path for single array.
3243- try_to_block (arrays [0 ])
3244- else :
3245- # Optimized for multiple arrays.
3246- xc .batched_block_until_ready (arrays )
3228+ for leaf in leaves :
3229+ if hasattr (leaf , "block_until_ready" ):
3230+ arrays .append (leaf )
3231+
3232+ # FIX: Only warn if the input is NOT empty but contains NO blockable objects.
3233+ # This allows mixed trees (arrays + ints) to pass silently, satisfying jakevdp,
3234+ # but warns on pure non-array inputs (ints), satisfying the original issue.
3235+ if leaves and not arrays :
3236+ import warnings
3237+ warnings .warn (
3238+ "jax.block_until_ready() was called on an input that contains no JAX arrays "
3239+ "or objects with a block_until_ready() method. This operation is a no-op." ,
3240+ UserWarning ,
3241+ stacklevel = 2
3242+ )
3243+
3244+ # Delegate to the optimized batch blocker
3245+ if arrays :
3246+ xc .batched_block_until_ready (arrays )
32473247
32483248 return x
32493249
0 commit comments