---
title: "SGLang v0.4: Zero-Overhead Batch Scheduler, Cache-Aware Load Balancer, Faster Structured Outputs"
author: "The SGLang Team"
date: "December 3, 2024"
previewImg: /images/blog/sglang_v0_4/nsys_no_idle.jpg
---

We’re excited to release SGLang v0.4, featuring significant performance improvements and new features:

- Zero-overhead batch scheduler: 1.1x increase in throughput.
- Cache-aware load balancer: up to 1.9x increase in throughput with a 3.8x higher cache hit rate.
- Data parallelism attention for DeepSeek models: up to 1.9x decoding throughput improvement.
- Fast structured outputs with xgrammar: up to 10x faster.

This blog provides a walkthrough of these updates. We welcome your feedback and contributions!

## Zero-Overhead Batch Scheduler

While LLM inference runs on GPUs, a substantial amount of work also has to be done by the CPU, such as batch scheduling, memory allocation, and prefix matching. An unoptimized inference engine can spend as much as [half of its time on CPU overhead](https://mlsys.wuklab.io/posts/scheduling_overhead/). SGLang has been known for its efficient batch scheduler from the start. In this new version, we pushed it to the extreme and achieved a near zero-overhead batch scheduler. The idea, also proposed in [NanoFlow](https://arxiv.org/abs/2408.12757), is simple: overlap CPU scheduling with GPU computation. The scheduler runs one batch ahead and prepares all the metadata required for the next batch. By doing this, we keep the GPUs always busy and hide expensive overheads such as the radix cache operations. The related code is [here](https://github.com/sgl-project/sglang/blob/85e1a6f3aa5a2288ca85fe3fe922c733b6533fa7/python/sglang/srt/managers/scheduler.py#L399). The implementation resolves data dependencies by creating future tokens and carefully scheduling CUDA events and synchronization. Below is an illustration of the overlapped CPU scheduler and GPU worker.

<img src="/images/blog/sglang_v0_4/scheduler.jpg" style="display: flex; margin-top: auto; margin-left: auto; margin-right: auto; margin-bottom: auto; width: 100%;"></img>
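
To make the overlap concrete, here is a minimal sketch of the one-batch-ahead loop. It is illustrative only, not SGLang's scheduler: `prepare_metadata`, `launch_decode_kernel`, and `copy_tokens_to_cpu` are hypothetical callables, and the real implementation additionally threads future token ids through the radix cache and manages CUDA streams.

```python
# Illustrative one-batch-ahead loop (hypothetical helpers passed in as callables).
import torch

def run_overlapped(batches, prepare_metadata, launch_decode_kernel, copy_tokens_to_cpu):
    prev_event = None       # completion marker for the previous GPU batch
    prev_gpu_tokens = None  # previous batch's output tokens, still on the GPU

    for batch in batches:
        # CPU work for the *current* batch (scheduling, memory allocation,
        # prefix matching) runs here while the GPU may still be busy with the
        # previous batch; that batch's unknown outputs are referenced through
        # "future token" placeholders instead of being waited for.
        metadata = prepare_metadata(batch, future_of=prev_gpu_tokens)

        # Launch GPU kernels asynchronously and record an event so we can tell
        # later when this batch finishes, without blocking the CPU now.
        gpu_tokens = launch_decode_kernel(metadata)
        event = torch.cuda.Event()
        event.record()

        # Only now wait for the *previous* batch and copy its tokens back;
        # by this point the GPU is already working on the current batch.
        if prev_event is not None:
            prev_event.synchronize()
            copy_tokens_to_cpu(prev_gpu_tokens)

        prev_event, prev_gpu_tokens = event, gpu_tokens
```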

We verified the zero-overhead claim using the Nsight profiling system. In the figure below, there are 5 consecutive decoding batches, and you can see that the GPU is never idle. (NOTE: This profile was obtained with the Triton attention backend; there is still a minor gap with the FlashInfer backend, which will be resolved in the next FlashInfer release.)

<img src="/images/blog/sglang_v0_4/nsys_no_idle.jpg" style="display: flex; margin-top: auto; margin-left: auto; margin-right: auto; margin-bottom: auto; width: 90%;"></img>

With this optimization, SGLang v0.4 squeezes the last bit of performance out of the GPU, achieving a 1.1x speedup over its previous version and a 1.3x speedup over other state-of-the-art baselines. The speedup is most significant for small models and large tensor parallelism sizes.

<img src="/images/blog/sglang_v0_4/llama_3_2_3b.svg" style="display: flex; margin-top: auto; margin-left: auto; margin-right: auto; margin-bottom: auto; width: 60%;"></img>

**Usage**: It is turned on by default, so you do not need to change anything!

**Reproduce benchmark**:
```
# zero-overhead batch scheduler (v0.4)
python3 -m sglang.launch_server --model meta-llama/Llama-3.2-3B-Instruct
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 500 --random-input 4096 --random-output 2048

# old batch scheduler (v0.3)
python3 -m sglang.launch_server --model meta-llama/Llama-3.2-3B-Instruct --disable-overlap
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 500 --random-input 4096 --random-output 2048
```

## Cache-Aware Load Balancer

SGLang v0.4 introduces a cache-aware load balancer for LLM inference engines. The load balancer predicts the prefix KV cache hit rate on each worker and selects the worker with the highest predicted match rate. Testing shows **up to a 1.9x throughput increase and a 3.8x higher cache hit rate**, with benefits scaling as the worker count increases. The figure below shows how a cache-aware load balancer differs from a naive round-robin load balancer for data parallelism. The cache-aware load balancer maintains an approximation of the actual radix tree on each worker. The tree is lazily updated with almost no overhead.

<img src="/images/blog/sglang_v0_4/cache_aware.png" style="display: flex; margin-top: auto; margin-left: auto; margin-right: auto; margin-bottom: auto; width: 90%;"></img>
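
As a rough illustration of the routing policy (not the actual Rust implementation), the sketch below keeps a plain prefix trie, records which workers are believed to hold each prefix, and sends a new request to the worker with the longest predicted prefix match, falling back to load for tie-breaking. The data structure, names, and tie-breaking rule here are simplifications assumed for illustration.

```python
# Simplified cache-aware routing sketch (plain trie; the real sglang-router
# maintains an approximate radix tree and its own balancing policy).

class TrieNode:
    def __init__(self):
        self.children = {}    # next token -> child node
        self.workers = set()  # workers believed to have this prefix cached

def record_prefix(root, tokens, worker):
    """Lazily note that `worker` has processed (and likely cached) `tokens`."""
    node = root
    for t in tokens:
        node = node.children.setdefault(t, TrieNode())
        node.workers.add(worker)

def pick_worker(root, tokens, loads):
    """Route to the worker with the longest predicted prefix match; break ties by load."""
    best = {w: 0 for w in loads}  # predicted matched-prefix length per worker
    node, depth = root, 0
    for t in tokens:
        node = node.children.get(t)
        if node is None:
            break
        depth += 1
        for w in node.workers:
            best[w] = depth
    return max(loads, key=lambda w: (best[w], -loads[w]))

# Example: worker 0 already served a long shared prefix, so it wins the route.
root = TrieNode()
record_prefix(root, ["You", "are", "a", "helpful", "assistant."], worker=0)
print(pick_worker(root, ["You", "are", "a", "helpful", "assistant.", "Hi!"], {0: 3, 1: 1}))  # -> 0
```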

Here are some benchmark results. The new cache-aware router significantly improves throughput.

| | SGLang v0.4 | SGLang v0.3 |
| :---- | :---- | :---- |
| Throughput (token/s) | 158,596 | 82,665 |
| Cache hit rate | 75% | 20% |

> The benchmark is conducted on a [workload](https://github.com/sgl-project/sglang/pull/1990) that has multiple long prefix groups, with each group perfectly balanced. Performance may vary with the characteristics of the workload, but the cache hit rate should improve significantly.

The key features of this router include:

- **Multi-Node Support**: Deploy workers across multiple machines and connect a single router to the distributed workers, allowing easy horizontal scaling while preserving cache awareness in a distributed setup.
- **Cache-Aware Routing**: Requests are sent to the workers with higher predicted hit rates, and load balancing is performed to avoid imbalance.
- **Communication-Free Design**: No worker synchronization is required for cache state; instead, the router uses the information it already passes along to maintain an "approximate tree".
- **High-Performance Implementation**: Built in pure Rust for high concurrency, with a low-overhead design, offering a 2x speedup compared to Python-based alternatives.
- **Standalone Package**: Published as "sglang-router", includes Python bindings, and features a CLI for easy usage.

### Usage

Installation:
```
pip install sglang-router
```

1. Co-launch Workers and Router

A drop-in replacement for the existing `--dp-size` parameter:
```
python -m sglang_router.launch_server \
  --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
  --dp-size 8
```

2. Router-Only Launch

Ideal for multi-node distributed processing: launch the workers on their own nodes first, then point the router at their URLs.
```
python -m sglang_router.launch_router \
  --worker-urls http://worker1:8000 http://worker2:8000
```

### Reproduce benchmark

```
# Hardware: 8x A100 80GB GPUs

# Launch with router
python -m sglang_router.launch_server \
  --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
  --dp-size 8

# Launch without router (baseline)
python -m sglang.launch_server \
  --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
  --dp-size 8

# Run benchmark against either server
python bench_serving.py \
  --host 127.0.0.1 \
  --port 30000 \
  --dataset-name generated-shared-prefix
```

Learn more by reading the [code](https://github.com/sgl-project/sglang/tree/main/rust). There is also a related paper, [Preble](https://arxiv.org/abs/2407.00023), which explores a different design and implementation and is likewise built on top of SGLang.

## Data Parallelism Attention for DeepSeek Models

The most common parallelism strategy for inference is tensor parallelism. However, it might not be the most efficient strategy for certain models. For example, DeepSeek models use MLA and have only one KV head. Using tensor parallelism across 8 GPUs for such a model duplicates the KV cache on every GPU and wastes memory.

To overcome this, we've implemented data parallelism (DP) for the multi-head latent attention (MLA) mechanism to improve throughput for DeepSeek models. By adopting DP for the attention component, the KV cache is significantly reduced, allowing for larger batch sizes. In our DP attention implementation, each DP worker handles different types of batches (prefill, decode, idle) independently. The data processed by attention is all-gathered across all workers before entering the Mixture-of-Experts (MoE) layer, and after the MoE layer it is redistributed back to each worker. The figure below illustrates this idea.

<img src="/images/blog/sglang_v0_4/dp_attention.svg" style="display: flex; margin-top: auto; margin-left: auto; margin-right: auto; margin-bottom: auto; width: 50%;"></img>
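
The sketch below shows the shape of this data flow with `torch.distributed`, assuming one process per DP worker, equal-sized local batches, and user-supplied `attn` and `moe` callables; the real implementation additionally handles unequal prefill/decode/idle batches and shards the experts.

```python
# Conceptual DP-attention data flow (illustrative; assumes an initialized
# torch.distributed process group with one rank per DP worker and equal
# local batch sizes).
import torch
import torch.distributed as dist

def dp_attention_block(local_hidden, attn, moe):
    """local_hidden: [local_tokens, hidden_dim] -- this DP worker's own batch."""
    # 1. Each DP worker runs MLA attention on its own batch only, so the single
    #    KV head is not duplicated across tensor-parallel ranks.
    local_out = attn(local_hidden)

    # 2. All-gather activations from every DP worker so the MoE layer sees the
    #    combined batch.
    world_size = dist.get_world_size()
    gathered = [torch.empty_like(local_out) for _ in range(world_size)]
    dist.all_gather(gathered, local_out)
    combined = torch.cat(gathered, dim=0)

    # 3. Run the MoE layer on the combined batch.
    moe_out = moe(combined)

    # 4. Redistribute: each worker keeps only the slice for its own tokens
    #    before the next attention layer.
    rank = dist.get_rank()
    n = local_out.shape[0]
    return moe_out[rank * n : (rank + 1) * n]
```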

Here are the benchmark results on 8 x H100 80GB GPUs. With this optimization, SGLang v0.4 achieves 1.9x higher decoding throughput than SGLang v0.3. We are working on further improving throughput by integrating expert parallelism for the MoE layers. You can check out the related PRs for [data parallelism](https://github.com/sgl-project/sglang/pull/1970) and [expert parallelism](https://github.com/sgl-project/sglang/pull/2203).

<img src="/images/blog/sglang_v0_4/deepseek_coder_v2.svg" style="display: flex; margin-top: auto; margin-left: auto; margin-right: auto; margin-bottom: auto; width: 60%;"></img>

**Usage:** Add the `--enable-dp-attention` option to turn on this feature. Currently, it is only supported for DeepSeek models.

**Reproduce benchmark:**
```
# Hardware: 8x H100 80GB GPUs
# If you see out-of-memory errors, try reducing `--mem-fraction-static` to a smaller value such as 0.75.

# SGLang w/ DP attention (v0.4)
python3 -m sglang.launch_server --model-path neuralmagic/DeepSeek-Coder-V2-Instruct-FP8 --disable-radix-cache --trust-remote-code --tp 8 --enable-dp-attention --mem-fraction-static 0.78
python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1 --random-output 512 --random-range-ratio 1 --num-prompts 10000

# SGLang w/o DP attention (v0.3)
python3 -m sglang.launch_server --model-path neuralmagic/DeepSeek-Coder-V2-Instruct-FP8 --disable-radix-cache --trust-remote-code --tp 8 --mem-fraction-static 0.78
python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1 --random-output 512 --random-range-ratio 1 --num-prompts 10000
```

## Fast Structured Outputs with XGrammar

SGLang has been the fastest inference engine for JSON decoding with its [Compressed Finite State Machine](https://lmsys.org/blog/2024-02-05-compressed-fsm/). With this new release, it becomes even faster by integrating a faster grammar backend, xgrammar.
According to the benchmark results, **SGLang + xgrammar can be up to 10x faster than other open-source solutions for JSON decoding tasks**. You can learn more in the xgrammar blog post:
[https://blog.mlc.ai/2024/11/22/achieving-efficient-flexible-portable-structured-generation-with-xgrammar](https://blog.mlc.ai/2024/11/22/achieving-efficient-flexible-portable-structured-generation-with-xgrammar).

**Usage**: Add `--grammar-backend xgrammar` when launching the server.
```
python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --grammar-backend xgrammar
```

You can then query the server with the OpenAI-compatible API. See an example at [https://sgl-project.github.io/backend/openai_api_completions.html#JSON](https://sgl-project.github.io/backend/openai_api_completions.html#JSON).
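
For instance, a schema-constrained request might look like the sketch below. It assumes the server above is running locally on port 30000 and accepts an OpenAI-style `response_format` with a JSON schema; the field layout may differ across versions, so treat the linked documentation page as authoritative.

```python
# Hypothetical structured-output request against a local SGLang server;
# see the linked docs for the exact, version-specific request format.
import json
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

city_schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "population": {"type": "integer"},
    },
    "required": ["name", "population"],
}

response = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "Give me the name and population of Paris as JSON."}],
    # The grammar backend (xgrammar) constrains decoding so the output matches the schema.
    response_format={
        "type": "json_schema",
        "json_schema": {"name": "city_info", "schema": city_schema},
    },
)
print(json.loads(response.choices[0].message.content))
```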

## Acknowledgment

The work in this blog post is mainly contributed by Byron Hsu, Ke Bao, Lianmin Zheng, Yineng Zhang, and Ziyi Xu. We thank Zhiqiang Xie, Liangsheng Yin, Shuo Yang, and Yilong Zhao for their discussions on the zero-overhead scheduler; Ying Sheng, Yichuan Wang, and Shiyi Cao for their discussions on the cache-aware load balancer; Jiashi Li for their discussion on data parallelism attention; and Yixin Dong for the amazing xgrammar library.

## Roadmap

It has been a great year, and we delivered many features following our [roadmap](https://github.com/sgl-project/sglang/issues/1487).
The community is also growing healthily, with more developers and adoption.
The focus of the next release will be on disaggregated prefill-decode, speculative decoding, multi-level radix cache, sequence parallelism, and more!