Status: Open
Labels: bug (Something isn't working)
Description
`jax.block_until_ready()` works on arbitrary pytrees and silently ignores any leaf of the pytree that is not a JAX array.
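A simplified model of that behavior, for illustration only. `model_block_until_ready` is a hypothetical name, and this is not JAX's actual implementation: it flattens the argument with `jax.tree_util.tree_leaves`, blocks on leaves that are `jax.Array`s, and skips everything else. An instance of a class that was never registered as a pytree is treated as a single opaque leaf, so the arrays stored inside it are never visited.

```python
import jax
import jax.numpy as jnp


def model_block_until_ready(tree):
    """Hypothetical, simplified model of jax.block_until_ready (not JAX's code)."""
    for leaf in jax.tree_util.tree_leaves(tree):
        if isinstance(leaf, jax.Array):
            leaf.block_until_ready()  # wait for this array's computation
        # any other leaf (e.g. an unregistered object) is skipped silently
    return tree


class Unregistered:
    def __init__(self):
        self.x = jnp.arange(3.0)


obj = Unregistered()
# The whole object is a single leaf; its .x array is never visited.
leaves = jax.tree_util.tree_leaves(obj)
print(len(leaves), leaves[0] is obj)  # 1 True
model_block_until_ready(obj)  # returns immediately, no error
```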
This can lead to surprises if an object is accidentally not registered as a pytree, or if the wrong object is passed to the function: in that case, `block_until_ready` is silently a no-op. It took me a while to track down a bug caused by something like the following:
```python
import jax
import jax.numpy as jnp


class Foo:  # forgot to register this as a pytree, oops
    def __init__(self):
        # stand-in for some expensive JAX computation
        self.x = jnp.ones((4096, 4096)) @ jnp.ones((4096, 4096))


jax.block_until_ready(Foo())  # this silently does not block
# unexpected behavior downstream of here...
```

I think this behavior should be opt-in rather than the default (as in the original version, #8536), with something like:
```python
# proposed API: `strict` is not an existing parameter of jax.block_until_ready
jax.block_until_ready(object())                # DeprecationWarning
jax.block_until_ready(object(), strict=True)   # raise exception
jax.block_until_ready(object(), strict=False)  # ok
```

System info (python version, jaxlib version, accelerator, etc.)
jax: 0.6.1
jaxlib: 0.6.1
numpy: 2.2.6
python: 3.12.11 | packaged by conda-forge | (main, Jun 4 2025, 14:38:53) [Clang 18.1.8 ]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Darwin', release='24.5.0', version='Darwin Kernel Version 24.5.0: Tue Apr 22 19:54:49 PDT 2025; root:xnu-11417.121.6~2/RELEASE_ARM64_T6000', machine='arm64')
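Until something like the proposed `strict` flag exists, a user-side guard can approximate `strict=True`. This is a sketch under stated assumptions, not a JAX API: `strict_block_until_ready` is a hypothetical helper that raises if any leaf is not a `jax.Array`, and the `Foo` below is registered with `jax.tree_util.register_pytree_node_class` so that its array field becomes a visible leaf for `jax.block_until_ready`.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


def strict_block_until_ready(tree):
    """Hypothetical helper approximating the proposed strict=True mode."""
    leaves = jax.tree_util.tree_leaves(tree)
    non_arrays = [type(leaf).__name__ for leaf in leaves if not isinstance(leaf, jax.Array)]
    if non_arrays:
        raise TypeError(f"block_until_ready got non-array leaves: {non_arrays}")
    return jax.block_until_ready(tree)


@register_pytree_node_class
class Foo:
    """Foo registered as a pytree, so its array field is a visible leaf."""

    def __init__(self, x):
        self.x = x

    def tree_flatten(self):
        return (self.x,), None  # (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


class Unregistered:
    def __init__(self):
        self.x = jnp.ones(3)


strict_block_until_ready(Foo(jnp.ones(3)))  # blocks on Foo.x
try:
    strict_block_until_ready(Unregistered())
except TypeError as err:
    print(err)  # block_until_ready got non-array leaves: ['Unregistered']
```

Note that this check also rejects NumPy arrays and Python scalars, which `jax.block_until_ready` currently accepts; a real `strict` mode might choose to allow them.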