add guide for SGLang-Jax on TPUs #103
JamesBrianD wants to merge 1 commit into AI-Hypercomputer:main
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
Force-pushed from bf43b6f to 8cd58db.
@bvandermoon Could you please review this PR?
bvandermoon left a comment
Hey @JamesBrianD, thanks for reaching out and for your contribution. Could you please provide a little more context around this PR?
Hello @bvandermoon, thanks for reviewing! Happy to provide context.
Hi @bvandermoon, friendly ping on this PR.
Thanks for the ping @JamesBrianD. I'm not an SGLang expert, so I'll try to find someone more knowledgeable to review this.
depksingh left a comment
two minor comments, otherwise LGTM
- **TPU chips per node**: 4 (v6e)
- **Total TPU chips**: 64
- **Tensor Parallelism (TP)**: 32 (for non-MoE layers)
- **Expert Tensor Parallelism (ETP)**: 64 (for MoE experts)
nit: Can we please add the TPU provisioning and SSH commands, as in the Qwen3 README, so that users who read only this README are aware of those steps?

### Launch Command

Run the following command **on each node**, replacing:
This looks like a manual process: SSH into each node and run the same command there, which can be time-consuming. Could you please check whether a command like the one below can be used to run it on all the workers at once? That would simplify the process.
gcloud compute tpus tpu-vm ssh tpu-name --zone=zone --worker=all --command='pip install "jax[tpu]==0.4.20" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
https://docs.cloud.google.com/tpu/docs/managing-tpus-tpu-vm
Since node-rank is the only parameter that changes per node, is there another way to pass it so that it doesn't depend on the command? That way we could run the same command on all nodes with the single gcloud command above. If not, I think the existing approach is fine.
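One possible middle ground, as a rough sketch rather than a tested recipe: drive the per-node launch from a single machine by looping over worker indices and reusing each index as the node rank. The node count follows from the configuration quoted above (64 chips / 4 chips per node = 16 nodes). `TPU_NAME`, `ZONE`, and `LAUNCH_CMD` are placeholders, the `--node-rank` flag spelling is assumed from the discussion, and the script only prints the commands instead of running them.

```shell
# Dry-run sketch: print one gcloud SSH command per worker, using the worker
# index as the node rank. All values below are placeholders (assumptions),
# not taken from the guide under review.
TPU_NAME="tpu-name"
ZONE="zone"
LAUNCH_CMD="your-launch-command"   # hypothetical; substitute the guide's launch command

TOTAL_CHIPS=64
CHIPS_PER_NODE=4
NUM_NODES=$(( TOTAL_CHIPS / CHIPS_PER_NODE ))   # 64 / 4 = 16 nodes

# Build (but do not execute) the SSH command for the worker with index $1.
build_worker_cmd() {
  printf "gcloud compute tpus tpu-vm ssh %s --zone=%s --worker=%s --command='%s --node-rank=%s'\n" \
    "$TPU_NAME" "$ZONE" "$1" "$LAUNCH_CMD" "$1"
}

# Emit the commands for ranks 0 .. NUM_NODES-1.
rank=0
while [ "$rank" -lt "$NUM_NODES" ]; do
  build_worker_cmd "$rank"
  rank=$(( rank + 1 ))
done
```

If the printed commands look right, they could be piped to a shell or run in the background per worker; whether that is preferable to the documented per-node steps depends on how the launch command behaves when started over SSH.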
This is a tutorial for running the sglang-jax project on TPUs.