jax.block_until_ready() should error on non-arrays by default #29744

@Gattocrucco

Description

jax.block_until_ready() accepts arbitrary pytrees and silently ignores any leaf that is not a JAX array.

This leads to surprises when a class is mistakenly left unregistered as a pytree, or when the wrong object is passed to the function: in either case, block_until_ready silently does nothing. It took me a while to track down a bug caused by something like the following:

import jax

class Foo:  # forgot to register this as pytree, oops
    def __init__(self):
        self.x = ...  # some expensive jax computation

jax.block_until_ready(Foo())  # this silently does not block

# unexpected behavior downstream of here...
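
A minimal sketch of why the call above does nothing: an unregistered object is treated as a single opaque leaf, so its attributes are never visited. Here `jnp.ones(3)` is only an assumed stand-in for the expensive computation, and the variable names are illustrative.

```python
import jax
import jax.numpy as jnp


class Foo:  # not registered as a pytree
    def __init__(self):
        self.x = jnp.ones(3)  # stand-in for the expensive computation


foo = Foo()
leaves = jax.tree_util.tree_leaves(foo)

# The whole object is one leaf; its attribute `x` is never reached.
is_single_opaque_leaf = len(leaves) == 1 and leaves[0] is foo

# No leaf is a JAX array, so there is nothing for block_until_ready to wait on.
contains_jax_array = any(isinstance(leaf, jax.Array) for leaf in leaves)
```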

I think silently ignoring non-arrays should be opt-in rather than the default (as in the original version, #8536), with something like:

jax.block_until_ready(object())  # DeprecationWarning
jax.block_until_ready(object(), strict=True)  # raise exception
jax.block_until_ready(object(), strict=False)  # ok
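
Until something like this exists, a user-side wrapper can approximate the strict mode. The sketch below is not part of JAX: `block_until_ready_strict` is a hypothetical helper name, and it assumes that every leaf should be a `jax.Array` (so NumPy arrays and Python scalars are rejected too, which may be stricter than desired).

```python
import jax
import jax.numpy as jnp


def block_until_ready_strict(tree):
    """Hypothetical helper: block like jax.block_until_ready, but raise on non-array leaves."""
    leaves = jax.tree_util.tree_leaves(tree)
    bad = [leaf for leaf in leaves if not isinstance(leaf, jax.Array)]
    if bad:
        kinds = sorted({type(leaf).__name__ for leaf in bad})
        raise TypeError(f"non-array leaves of type(s) {kinds} would be silently ignored")
    return jax.block_until_ready(tree)


class Foo:  # deliberately not registered as a pytree
    def __init__(self):
        self.x = jnp.ones(3) * 2


# The unregistered object is now rejected instead of being skipped.
try:
    block_until_ready_strict(Foo())
    raised = False
except TypeError:
    raised = True

# A pytree made only of JAX arrays passes through unchanged.
result = block_until_ready_strict({"a": jnp.arange(3)})
```

Checking leaves before blocking keeps the wrapper independent of how block_until_ready is implemented internally, at the cost of one extra tree traversal.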

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.6.1
jaxlib: 0.6.1
numpy:  2.2.6
python: 3.12.11 | packaged by conda-forge | (main, Jun  4 2025, 14:38:53) [Clang 18.1.8 ]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Darwin', release='24.5.0', version='Darwin Kernel Version 24.5.0: Tue Apr 22 19:54:49 PDT 2025; root:xnu-11417.121.6~2/RELEASE_ARM64_T6000', machine='arm64')

Labels: bug
