---
title: "SGLang-Jax: An Open-Source Solution for Native TPU Inference"
author: "The SGLang-Jax Team"
date: "October 29, 2025"
previewImg: /images/blog/sglang_jax/cover.jpg
---

We're excited to introduce SGLang-Jax, a state-of-the-art open-source inference engine built entirely on Jax and XLA.
It leverages SGLang's high-performance server architecture and uses Jax to compile the model's forward pass.
By combining SGLang and Jax, this project delivers fast, native TPU inference while maintaining support for advanced features like continuous batching, prefix caching, tensor and expert parallelism, speculative decoding, kernel fusion, and highly optimized TPU kernels.

Benchmarks show that SGLang-Jax matches or outperforms other TPU inference solutions.
The source code is available at [https://github.com/sgl-project/sglang-jax](https://github.com/sgl-project/sglang-jax).

## Why a Jax Backend?

While SGLang was originally built on PyTorch, the community has been eager for Jax support.
We built a Jax backend for several key reasons:

- Jax is designed from the ground up for TPUs. For maximum performance without compromise, Jax is the clear choice. With Google expanding public access to TPUs, we expect Jax + TPU to gain significant traction and enable cost-efficient inference.
- Leading AI labs, including Google DeepMind, xAI, Anthropic, and Apple, already rely on Jax. Using the same framework for both training and inference reduces maintenance overhead and eliminates drift between the two stages.
- Jax + XLA is a proven, compilation-driven stack that excels on TPUs and performs well across a broad range of custom TPU-like AI chips.

## Architecture

The diagram below illustrates the SGLang-Jax architecture. The entire stack is pure Jax, resulting in clean code with minimal dependencies.

On the input side, it accepts requests via OpenAI-compatible APIs and utilizes SGLang's efficient RadixCache for prefix caching along with its overlap scheduler for low-overhead batching.
The scheduler pre-compiles Jax computation graphs for different batch sizes.
On the model side, we implement models in Flax and use `shard_map` for various parallelism strategies.
The two core operators, attention and MoE, are implemented as custom Pallas kernels.

<img src="/images/blog/sglang_jax/architecture.png" style="display:block; margin: auto; width: 85%;"></img>
<p style="color:gray; text-align: center;">The architecture of SGLang-Jax</p>
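
To give a flavor of how parallelism is expressed in this stack, here is a minimal, self-contained sketch of a tensor-parallel linear layer written with `shard_map`. The mesh axis name `tp` and the toy shapes are illustrative assumptions, not SGLang-Jax's actual model code.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# Build a 1-D device mesh with a single tensor-parallel ("tp") axis.
mesh = Mesh(np.array(jax.devices()), axis_names=("tp",))

def tp_linear(x, w):
    # Per-device body: x is replicated, w holds this device's slice of the
    # output dimension, so a column-parallel matmul needs no communication.
    return x @ w

tp_linear_sharded = shard_map(
    tp_linear,
    mesh=mesh,
    in_specs=(P(None, None), P(None, "tp")),  # replicate x, shard w's columns
    out_specs=P(None, "tp"),                  # output stays column-sharded
)

x = jnp.ones((8, 128), dtype=jnp.bfloat16)
w = jnp.ones((128, 512), dtype=jnp.bfloat16)
y = jax.jit(tp_linear_sharded)(x, w)          # [8, 512], sharded over "tp"
```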

## Key Optimizations

### Integrating Ragged Paged Attention v3

We integrated Ragged Paged Attention v3 ([RPA v3](https://github.com/vllm-project/tpu-inference/tree/main/tpu_inference/kernels/ragged_paged_attention/v3)) and extended it to support SGLang features:
- To support EAGLE speculative decoding, we added a custom mask to RPA v3 for use in the verification phase (a sketch of such a mask follows this list).
- We tuned the kernel's grid and block configurations for different scenarios to achieve better performance.
- We made it compatible with RadixCache.
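
The sketch below shows, in plain Jax, the kind of ancestor-based mask the verification phase needs: each draft token may attend only to itself and to its ancestors in the draft tree. It is illustrative only; the real mask lives inside the RPA v3 Pallas kernel, and the parent indices here are made-up examples.

```python
import jax.numpy as jnp

def tree_mask(parents):
    """parents[i] = index of draft token i's parent, or -1 for a tree root.

    Returns mask[i, j] = True iff draft token i may attend to draft token j,
    i.e. j is i itself or one of i's ancestors in the draft tree.
    """
    n = parents.shape[0]
    mask = jnp.eye(n, dtype=bool)            # every token attends to itself
    anc = parents                            # current ancestor of each token
    for _ in range(n - 1):                   # walk up the tree at most n-1 hops
        valid = anc >= 0
        rows, cols = jnp.arange(n), jnp.clip(anc, 0)
        mask = mask.at[rows, cols].set(mask[rows, cols] | valid)
        anc = jnp.where(valid, parents[cols], -1)
    return mask

# A depth-two tree: token 0 is the root, tokens 1 and 2 branch from it,
# and token 3 extends token 1.
print(tree_mask(jnp.array([-1, 0, 0, 1])).astype(int))
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 0 1 0]
#  [1 1 0 1]]
```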

### Reducing Scheduling Overhead

Sequential operations on the CPU and TPU during the forward pass can hurt performance. However, operations on different devices can be decoupled: for example, we can launch a computation on the TPU and immediately start preparing the next batch. To improve performance, our scheduler overlaps CPU processing with TPU computation.

In the overlap event loop, the scheduler uses a result queue and threading events to pipeline CPU and TPU work. While the TPU processes batch N, the CPU prepares batch N+1. To maximize overlap between CPU and TPU, SGLang-Jax carefully sequences operations based on profiling results. For Qwen/Qwen3-32B, we reduced the time gap between batches from approximately 12 ms to 38 µs for prefill and from approximately 7 ms to 24 µs for decode. More details can be found in our previous [blog post](https://lmsys.org/blog/2024-12-04-sglang-v0-4/).
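
The event loop below is a deliberately simplified sketch of this idea, not the actual SGLang-Jax scheduler: a worker thread runs device steps and pushes results into a queue while the main loop prepares the next batch on the CPU. `prepare_batch`, `run_model`, and `process_result` are stand-ins for the real components.

```python
import queue
import threading

def overlap_loop(prepare_batch, run_model, process_result, num_steps):
    """Pipeline CPU batch preparation with device execution."""
    launch_q, result_q = queue.Queue(), queue.Queue()

    def device_worker():
        # Runs device steps in submission order; results flow back via result_q.
        while (batch := launch_q.get()) is not None:
            result_q.put(run_model(batch))

    threading.Thread(target=device_worker, daemon=True).start()

    batch = prepare_batch(0)                    # CPU work for batch 0
    for step in range(num_steps):
        launch_q.put(batch)                     # device starts batch N...
        if step + 1 < num_steps:
            batch = prepare_batch(step + 1)     # ...while the CPU builds batch N+1
        process_result(result_q.get())          # consume batch N's output
    launch_q.put(None)                          # shut the worker down
```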

<img src="/images/blog/sglang_jax/profile_overlap.jpg" style="display:block; margin: auto; width: 85%;"></img>
<p style="color:gray; text-align: center;">Profile with the overlap scheduler. The gaps between batches are minimal.</p>

<img src="/images/blog/sglang_jax/profile_no_overlap.jpg" style="display:block; margin: auto; width: 85%;"></img>
<p style="color:gray; text-align: center;">Profile without the overlap scheduler. Note the large gaps (CPU overhead) between batches.</p>

### MoE Kernel Optimization

The MoE layer currently supports two implementation strategies: EPMoE and FusedMoE.
In EPMoE, we integrated the **Megablox GMM** operator, replacing the previous implementation based on Jax's `ragged_dot`.
Megablox GMM is designed specifically for MoE workloads and efficiently handles the variable-sized expert groups described by `group_sizes`, eliminating unnecessary computation and non-contiguous memory accesses. In typical configurations, this operator delivers a **3–4× end-to-end (e2e) ITL speedup** compared to Jax's native `ragged_dot` implementation.
Combined with efficient token permutation (permute/unpermute), expert-parallel communication via `ragged_all_to_all`, and adaptive tiling strategies, EPMoE significantly boosts overall throughput and works well in scenarios requiring cross-device parallelism with many experts.
In contrast, FusedMoE fuses all expert computations into dense einsum operations with no inter-device communication overhead. It is better suited to cases with large individual experts but few total experts (e.g., fewer than 64), and it also serves as a lightweight fallback for easier debugging and correctness validation.
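
To make the dense FusedMoE strategy concrete, here is a toy forward pass in plain Jax: every expert processes every token via einsums, and the outputs are then mixed with the top-k routing weights. The shapes, SiLU activation, and softmax-then-top-k gating are simplifying assumptions for illustration, not the SGLang-Jax implementation.

```python
import jax
import jax.numpy as jnp

def fused_moe(x, w_up, w_down, router, top_k=2):
    """Toy FusedMoE forward pass.

    x:      [tokens, d]        hidden states
    w_up:   [experts, d, f]    per-expert up projection
    w_down: [experts, f, d]    per-expert down projection
    router: [d, experts]       routing logits projection
    """
    logits = x @ router                                          # [t, e]
    weights, idx = jax.lax.top_k(jax.nn.softmax(logits), top_k)  # [t, k]
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # Dense path: run every expert on every token; no all-to-all needed.
    h = jax.nn.silu(jnp.einsum("td,edf->etf", x, w_up))          # [e, t, f]
    y = jnp.einsum("etf,efd->etd", h, w_down)                    # [e, t, d]
    # Pick each token's top-k expert outputs and mix by routing weight.
    tokens = jnp.arange(x.shape[0])[:, None]                     # [t, 1]
    picked = y[idx, tokens, :]                                   # [t, k, d]
    return jnp.einsum("tk,tkd->td", weights, picked)

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (16, 64))
w_up = jax.random.normal(key, (8, 64, 256)) * 0.02
w_down = jax.random.normal(key, (8, 256, 64)) * 0.02
router = jax.random.normal(key, (64, 8)) * 0.02
out = jax.jit(fused_moe)(x, w_up, w_down, router)                # [16, 64]
```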

### Speculative Decoding

SGLang-Jax implements EAGLE-based speculative decoding, also known as Multi-Token Prediction (MTP).
This technique accelerates generation by using a lightweight draft head to predict multiple tokens, which are then verified in parallel with a single pass through the full model.
To implement tree-based MTP verification, SGLang-Jax adds non-causal mask support on top of Ragged Paged Attention v3, enabling parallel verification of tree-structured, non-causal draft tokens.
We currently support EAGLE-2 and EAGLE-3, and we plan to keep optimizing the kernel implementation and to add support for different attention backends at the various MTP stages.
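
For intuition, the snippet below sketches the simplest (chain-shaped, greedy) form of draft verification: the target model scores the whole draft in one pass, the longest agreeing prefix is accepted, and a bonus token is taken from the target at the first mismatch. The tree-based EAGLE variant generalizes this with the non-causal mask described above; `target_logits_fn` is a placeholder, not an SGLang-Jax API.

```python
import jax.numpy as jnp

def verify_chain_draft(target_logits_fn, prompt, draft_tokens):
    """Greedy chain verification of k draft tokens in one target forward pass.

    target_logits_fn(tokens) -> [len(tokens), vocab] logits from the target model.
    """
    p, k = len(prompt), len(draft_tokens)
    seq = jnp.concatenate([prompt, draft_tokens])
    logits = target_logits_fn(seq)                    # one pass over prompt + drafts
    # logits[i] predicts token i + 1, so positions p-1 .. p+k-2 score the drafts.
    preds = jnp.argmax(logits[p - 1 : p + k - 1], axis=-1)
    match = (preds == draft_tokens).astype(jnp.int32)
    num_accepted = int(jnp.cumprod(match).sum())      # longest agreeing prefix
    accepted = draft_tokens[:num_accepted]
    # Bonus token: the target's own choice at the first mismatch (or one past
    # the last draft when everything was accepted).
    bonus = jnp.argmax(logits[p - 1 + num_accepted])
    return jnp.concatenate([accepted, bonus[None]])
```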

## TPU Performance

With all of these optimizations in place, SGLang-Jax matches or outperforms other TPU inference solutions.

### Setup

We benchmarked SGLang-Jax against vLLM-TPU. Full instructions are available [here](https://github.com/sgl-project/sglang-jax/issues/270).
We used `Qwen/Qwen3-32B` on TPU v6e-4, with SGLang-Jax at version main-af32f095880ff676ed23eec19bc79584b5e20717 and vLLM-TPU (vllm-tpu==0.11.1).

### Results

<img src="/images/blog/sglang_jax/tpu_performance.png" style="display:block; margin: auto; width: 85%;"></img>
<p style="color:gray; text-align: center;">SGLang-Jax matches vLLM-TPU on prefill because of similar kernel optimizations and outperforms it on decode thanks to the overlap scheduler.</p>

<img src="/images/blog/sglang_jax/gpu_performance.png" style="display:block; margin: auto; width: 85%;"></img>
<p style="color:gray; text-align: center;">The TPU setup achieves lower latency (TTFT and ITL) and higher input throughput across various batch sizes.</p>

## Usage

### Installing SGLang-Jax and Launching a Server

Install:
```bash
# With uv (from PyPI)
uv venv --python 3.12 && source .venv/bin/activate
uv pip install sglang-jax

# Or from source
git clone https://github.com/sgl-project/sglang-jax
cd sglang-jax
uv venv --python 3.12 && source .venv/bin/activate
uv pip install -e python/
```

Launch a server:
```bash
MODEL_NAME="Qwen/Qwen3-8B" # or "Qwen/Qwen3-32B"

JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache \
uv run python -u -m sgl_jax.launch_server \
    --model-path ${MODEL_NAME} \
    --trust-remote-code \
    --tp-size=4 \
    --device=tpu \
    --mem-fraction-static=0.8 \
    --chunked-prefill-size=2048 \
    --download-dir=/tmp \
    --dtype=bfloat16 \
    --max-running-requests 256 \
    --page-size=128
```
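
Once the server is up, you can exercise it through its OpenAI-compatible API. The short client below assumes SGLang's default port (30000) and the standard `/v1` path; adjust the base URL to match your launch flags.

```python
# Query the running server through its OpenAI-compatible endpoint.
# The port below is an assumption (SGLang's default); change it if you
# launched the server with a different --port.
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="Qwen/Qwen3-8B",
    messages=[{"role": "user", "content": "Give me a one-line intro to TPUs."}],
    max_tokens=64,
    temperature=0,
)
print(response.choices[0].message.content)
```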

### Using TPU via GCP Console

You can find the TPU option under Menu → Compute Engine and click Create TPU in the console.
Note: only certain zones support specific TPU versions. Remember to set the TPU software version to v2-alpha-tpuv6e.
Under the Compute Engine menu, go to Settings → Metadata, click the SSH Keys button, and add your public key.
Once the TPU server is created, you can log in using the External IP and the public-key username shown in the console.
See also: https://docs.cloud.google.com/tpu/docs/setup-gcp-account

<img src="/images/blog/sglang_jax/gcp_usage_1.png" style="display:block; margin: auto; width: 85%;"></img>

### Using TPU via SkyPilot

We recommend using SkyPilot for daily development.
You can quickly set up SkyPilot and find scripts for launching development machines and running tests in the sglang-jax repository.

Install SkyPilot for GCP: https://docs.skypilot.co/en/latest/getting-started/installation.html#gcp
Then launch [sgl-jax.yaml](https://github.com/sgl-project/sglang-jax/blob/cdd6600a70ecb396382a510da9ea59c91a9ea2c0/scripts/tpu_resource.yaml#L1):

```bash
sky launch sgl-jax.yaml --cluster=sgl-jax-skypilot-v6e-4 --infra=gcp -i 30 --down -y --use-spot
```

This command finds the lowest-cost TPU spot instance across regions and automatically shuts the instance down after 30 minutes of idle time. It also installs the sglang-jax environment for you.
Once setup is complete, you can log in directly using `ssh <cluster-name>` (e.g., `ssh sgl-jax-skypilot-v6e-4`) without tracking the external IP address.

## Roadmap

The community is working with the Google Cloud team and multiple partners on the following roadmap.

- Model support and optimizations
  - Optimize Grok2, Ling/Ring, DeepSeek V3, and GPT-OSS
  - Support MiMo-Audio, Wan 2.1, and Qwen3 VL
- TPU-optimized kernels
  - Quantization kernels
  - Communication and computation overlap kernels
  - MLA kernels
- RL integration with [tunix](https://github.com/google/tunix)
  - Weight synchronization
  - Pathways and multi-host support
- Advanced serving features
  - Prefill-decode disaggregation
  - Hierarchical KV cache
  - Multi-LoRA batching

## Acknowledgments

**SGLang-Jax team**: sii-xinglong, jimoosciuc, Prayer, aolemila, JamesBrianD, zkkython, neo, leos, pathfinder-pf, Ying Sheng, Hongzhen Chen, Jiacheng Yang, Ke Bao, Qinghan Chen

**Google**: Google Cloud Team

**InclusionAI**: Junping Zhao, Guowei Wang, Yuhong Guo, Zhenxuan Pan