
Which jax version are you using?

I can run this code successfully on my end, with the following stdout (I'm using jax 0.3.15):

```
[GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0)]
[[0 1 2 3]
 [4 5 6 7]] [DeviceArray([[0, 1, 2, 3],
             [4, 5, 6, 7]], dtype=int32), DeviceArray([[0, 1, 2, 3],
             [4, 5, 6, 7]], dtype=int32)]
```

```
# pip list | grep jax
jax                           0.3.15
jaxlib                        0.3.15
```
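If you want to check your jax and jaxlib versions from Python rather than with `pip list`, here is a minimal sketch using only the standard library (`importlib.metadata`, Python 3.8+). The helper name `installed_versions` is just an illustration, not part of jax. It reports `None` for a package that isn't installed instead of raising. In a session where jax already imports, `jax.__version__` gives you the same jax version string.

```python
# Sketch: report installed jax/jaxlib versions, similar to `pip list | grep jax`.
# Uses only the standard library, so it runs even if jax fails to import.
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("jax", "jaxlib")):
    """Return {package name: version string, or None if not installed}."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            # Package metadata not found in the current environment.
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name:<30}{ver if ver is not None else 'not installed'}")
```

A mismatch between the jax and jaxlib versions is worth ruling out too, since the two packages are installed separately.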

Answer selected by coderlemon17
Category: Q&A