
When you write assert xyz or if xyz, Python calls the __bool__ method of the object xyz.
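
A quick plain-Python illustration of that rule. The class name Flag is made up for this sketch and has nothing to do with JAX; it just counts how often Python asks for its truth value:

```python
class Flag:
    """Toy object (hypothetical) that counts every truth-value check on it."""

    def __init__(self, value):
        self.value = value
        self.calls = 0

    def __bool__(self):
        # Python invokes this for `if flag:` and `assert flag`.
        self.calls += 1
        return bool(self.value)


flag = Flag(True)
if flag:      # calls Flag.__bool__
    pass
assert flag   # calls Flag.__bool__ again (assert statements are skipped under `python -O`)
print(flag.calls)  # 2
```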

For two DeviceArrays x and y, the expression z = x < y returns another DeviceArray with dtype bool. The __bool__ method for DeviceArrays is defined here: https://github.com/google/jax/blob/b3a62cd3f2be15a7ed23771b371835e2977961be/jax/_src/device_array.py#L269
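
A small check of what the comparison produces. Note the assumption: in the linked commit the result type is DeviceArray, while newer JAX releases use jax.Array instead, so the printed type name depends on your version. The dtype and shape do not:

```python
import jax.numpy as jnp

x = jnp.array(1.0)
y = jnp.array(2.0)

# The comparison stays a JAX array; nothing has been turned into a Python bool yet.
z = x < y
print(type(z))   # DeviceArray in older releases, a jax.Array subclass in newer ones
print(z.dtype)   # bool
print(z.shape)   # ()
```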

This currently forwards to z._value.__bool__, and computing z._value requires copying the device buffer to a NumPy array on the CPU: https://github.com/google/jax/blob/b3a62cd3f2be15a7ed23771b371835e2977961be/jax/_src/device_array.py#L144-L150
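
The effect can be written out by hand. This is a sketch of the equivalent behavior, not the literal internal code path (which differs between the linked DeviceArray implementation and newer jax.Array versions): bool(z) has to end up with a host-side value, which is the same thing np.asarray(z) gives you explicitly:

```python
import numpy as np
import jax.numpy as jnp

z = jnp.array(1.0) < jnp.array(2.0)

# Implicit: bool(z) calls z.__bool__, which needs the value on the host.
as_python_bool = bool(z)

# Explicit equivalent: copy z into a host-side NumPy array, then let NumPy
# compute the truth value of that 0-d array.
host_value = np.asarray(z)
print(as_python_bool, bool(host_value))  # True True
```

Either way the device has to finish computing z and hand the result back before Python can continue.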

I don't think there's any way around this: Python control flow happens on the CPU, because Python runs on…
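
One consequence worth seeing: inside jax.jit there is no concrete value to copy back, so a Python if on x < y fails at trace time. If the goal is to keep the decision on the device, one common option (not from the original reply; shown only as an illustration) is to stage the branch into the computation with jax.lax.cond. The function names below are hypothetical:

```python
import jax
import jax.numpy as jnp


@jax.jit
def python_branch(x, y):
    # Under jit, x < y is an abstract tracer, so Python's `if`
    # (which calls __bool__) cannot be evaluated.
    if x < y:
        return x
    return y


@jax.jit
def device_branch(x, y):
    # lax.cond compiles both branches into the computation and picks one
    # on the device, so the predicate never returns to Python.
    return jax.lax.cond(x < y, lambda a, b: a, lambda a, b: b, x, y)


caught_error = False
try:
    python_branch(jnp.array(1.0), jnp.array(2.0))
except jax.errors.ConcretizationTypeError as err:
    # Newer JAX raises TracerBoolConversionError, a subclass of this error.
    caught_error = True
    print("Python `if` failed under jit:", type(err).__name__)

result = device_branch(jnp.array(1.0), jnp.array(2.0))
print(result)
```

Note that reading result back into Python (printing it, or calling bool on it) still requires a host transfer; lax.cond only avoids the round trip for the branch decision itself.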

@luweizheng
@jakevdp
Answer selected by luweizheng