NCCL4JAX - Using JIT on the allreduce #11333
Unanswered
sophia-estrela asked this question in Q&A
Replies: 2 comments 2 replies
-
JAX JIT will replace all non-static arguments with Tracers, and Tracers cannot be converted into |
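A minimal sketch of the usual workaround: mark the integer parameters as static, so `jax.jit` passes them through as ordinary Python values instead of turning them into tracers. This assumes the integers only select behavior or sizes and never need to be traced. The function `reduce_prefix` and the op codes are hypothetical (the codes are assumed to mirror NCCL's `ncclSum = 0` and `ncclMax = 2`); this is not the NCCL4JAX code.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical integer op codes, assumed to mirror NCCL's ncclRedOp_t values.
OP_SUM = 0
OP_MAX = 2


# `count` and `op` are static: jit hands them over as plain Python ints and
# recompiles whenever their values change, instead of replacing them with
# tracers. That lets Python code branch on them and use them as sizes.
@partial(jax.jit, static_argnames=("count", "op"))
def reduce_prefix(sendbuf, count, op):
    chunk = sendbuf[:count]  # slicing needs a concrete int; a tracer would fail
    if op == OP_SUM:  # a Python `if` on a static value is fine under jit
        return jnp.sum(chunk)
    if op == OP_MAX:
        return jnp.max(chunk)
    raise ValueError(f"unsupported op code {op}")


x = jnp.arange(8.0)
print(reduce_prefix(x, count=4, op=OP_SUM))  # sums 0 + 1 + 2 + 3
```

The trade-off is that every distinct value of a static argument triggers a new compilation, which is usually acceptable for a handful of op codes and buffer sizes.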
-
This isn't an answer to your question, but are you aware that JAX already supports using NCCL on GPU, including across hosts? See |
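For comparison, a minimal sketch of an all-reduce written with JAX's built-in collectives, assuming a JAX release that still provides `jax.pmap`: `jax.lax.psum` sums a value across every device mapped over the named axis, and on GPU, XLA typically implements such collectives with NCCL. The function name `allreduce_sum` and the axis name `"devices"` are illustrative choices.

```python
import jax
import jax.numpy as jnp


def allreduce_sum(x):
    # Sum `x` across all devices participating in the named axis "devices".
    return jax.lax.psum(x, axis_name="devices")


n_dev = jax.local_device_count()
# One row per local device; row i holds the value i in every element.
per_device = jnp.stack([jnp.full((4,), float(i)) for i in range(n_dev)])

# pmap runs allreduce_sum once per device; afterwards every row holds the total.
reduced = jax.pmap(allreduce_sum, axis_name="devices")(per_device)
print(reduced)
```

With a single device the sketch still runs; the "reduction" is then just the identity.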
-
Hi! I am currently building an NCCL library for JAX; specifically, I am finishing the all-reduce operation right now.
I have already finished the XLA part, and I am now trying to make the function work with JIT.
I have something like the code above, where count, dataType, and op are integers, sendbuffer is an array (which will be placed on the device), and comm is the communicator (an instance of the NCCL communicator class).
But when I try to run it, I get the following error:
When I remove the variables count, dataType, and op, it works. I do not understand why this fails, since they start out as simple integers and the conversion to DynamicJaxprTracer is done by jit itself.
Can someone explain what is wrong, or why this raises an error?
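The behavior described above can be reproduced without NCCL. In this minimal sketch, a plain Python int passed to a jitted function arrives as a tracer, and asking for its concrete value with `int()` makes JAX raise a `TypeError` subclass (`jax.errors.ConcretizationTypeError` in current releases). The function `scale_by_count` is hypothetical; it only mimics a wrapper that hands its arguments to code expecting real integers.

```python
import jax
import jax.numpy as jnp


def scale_by_count(sendbuf, count):
    # Under jax.jit, `count` arrives as a tracer, not the int the caller passed.
    n = int(count)  # demanding a concrete Python int from a tracer fails
    return sendbuf * n


# Without jit, `count` stays a plain int and this works.
print(scale_by_count(jnp.ones(4), 3))

try:
    jax.jit(scale_by_count)(jnp.ones(4), 3)
except TypeError as err:  # ConcretizationTypeError subclasses TypeError
    print("jit rejected the conversion:", type(err).__name__)

# Marking `count` (position 1) as static keeps it a Python int under jit.
print(jax.jit(scale_by_count, static_argnums=1)(jnp.ones(4), 3))
```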
GitHub repository with the code
Thank you for the help!