Get channel_id in HLO dump #22075
Unanswered
DhruvKumar1 asked this question in Q&A
Replies: 0 comments
I'm interested in looking at the HLO dumps that XLA produces when running a JAX program, in particular for multi-GPU operations such as all_gather and all_reduce. I set
--xla_dump_to=
and can see the HLO dumps (before and after optimizations). For the all_reduce I can see the replica groups and the tensor sizes. However, I cannot see the channel_id that the XLA Operation Semantics documentation lists as an optional argument. I have seen HLO snippets online where channel_id is populated in all-reduce-start. Could someone explain when this field is populated, and give a simple Python example that produces an HLO all_gather with this field populated?

This is the simple code I am running that induces an all_gather and an all_reduce (running on 2 GPUs):
This is the output I get for the module containing the all_reduce:
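The original code and output did not survive in this copy of the thread. As a hedged sketch (not the author's program), the snippet below issues an all_gather and a psum inside shard_map over a 2-device mesh, lowers the jitted function, and prints each collective line of the HLO together with whether it carries channel_id. Assumptions: XLA_FLAGS=--xla_force_host_platform_device_count=2 fakes 2 CPU devices so it runs without GPUs; the shard_map import location differs between JAX versions; collective_lines is a hypothetical helper. Our understanding, which should be checked against your JAX version, is that JAX attaches a channel handle (printed as channel_id=...) to collectives lowered in SPMD mode (jit/shard_map), whereas pmap's replica-based collectives are emitted without one. On GPU, the after-optimizations dump would typically show these as asynchronous all-reduce-start / all-gather-start ops.

```python
import os

# Assumption: no GPUs available, so ask XLA's CPU backend to expose 2 devices.
# This must happen before jax is imported; on a real 2-GPU machine, drop it.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=2")

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older JAX releases


def collectives(x):
    # Runs once per shard (per device) along the mesh axis "x".
    gathered = jax.lax.all_gather(x, "x", tiled=True)  # every shard sees the full array
    total = jax.lax.psum(x, "x")                       # elementwise sum across shards
    return gathered, total


def collective_lines(hlo_text):
    """Hypothetical helper: return HLO instruction lines for gather/reduce collectives."""
    ops = (" all-gather(", " all-reduce(", " all-gather-start(", " all-reduce-start(")
    return [line.strip() for line in hlo_text.splitlines()
            if any(op in line for op in ops)]


mesh = Mesh(np.array(jax.devices()[:2]), axis_names=("x",))
f = jax.jit(shard_map(
    collectives,
    mesh=mesh,
    in_specs=P("x"),
    # gathered: each shard's full copy is stacked along "x" (so we do not rely
    # on replication checks); total: replicated after psum.
    out_specs=(P("x"), P()),
))

x = jnp.arange(8.0)
hlo = f.lower(x).as_text(dialect="hlo")  # HLO before XLA's optimization passes

for line in collective_lines(hlo):
    # True/False: does this collective carry a channel_id attribute?
    print("channel_id" in line, line)
```

To compare against replica-based collectives, the same body can be run under jax.pmap and its lowered text filtered the same way.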