In this case, since you explicitly specified the device of A and did not specify the device of b, array A is committed to its device and b is uncommitted. So if you do a simple operation like A + b (assuming the shapes are compatible), the computation will happen on the device where A resides. For example:

>>> import jax
>>> A = jax.numpy.ones((4, 4), device=jax.devices('cpu')[0])
>>> print(A.device, A.committed)
TFRT_CPU_0 True

>>> b = jax.numpy.arange(4.0)
>>> print(b.device, b.committed)
cuda:0 False

>>> c = A + b
>>> print(c.device, c.committed)
TFRT_CPU_0 True

Now, if I have a JIT-compiled function which takes A and b as arguments, where will the computations happen?
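
One way to check this empirically is sketched below. It assumes the behavior described in the JAX FAQ, namely that jitted functions follow committed inputs the same way ordinary operations do. The function name `f` and the variable `cpu` are illustrative only, and `jax.device_put` is used here in place of the `device=` argument so the sketch does not depend on a recent JAX version.

```python
import jax
import jax.numpy as jnp

# The CPU backend is available even when a GPU is the default device.
cpu = jax.devices("cpu")[0]

# Committed: explicitly placed on a specific device.
A = jax.device_put(jnp.ones((4, 4)), cpu)

# Uncommitted: created on the default device without an explicit placement.
b = jnp.arange(4.0)

@jax.jit
def f(A, b):
    # Illustrative function; b broadcasts against each row of A.
    return A + b

c = f(A, b)

# Inspect where the result lives; under the assumption above this is
# the device that A was committed to.
print(c.devices())
```

If two arguments were committed to different devices, the same call would be expected to raise an error rather than pick one of them.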

Fo…

@jakevdp
Answer selected by YigitElma