Commit 577104b

Support TP 🎉 (#72)

* Initial support for TP
* Use random initialization
* Fix PP forward
* Downgrade to torch 2.6.0
* Fix env setting for MAX_JOBS
* Downgrade to torch 2.5.1
* Fix TP group init
* Fix annotation
* Make llama compatible for tp
* Make chatglm compatible for TP
* Make Qwen3 compatible for TP
* Remove weight_loader in fused_moe
* Make fused_moe compatible for TP; Abstract weight load function
* Make qwen_moe compatible for tp
* Make mixtral compatible for TP
* Update readme
* Abstract module attention; Clean up code for TP attention; Clean up code for model weights loading for glm
* Add MoE tuing config for A100 PCIE 40GB
* Refactor scheduler.py and AllocatorID
* Refactor IDAllocator
* Refactor worker scheduler
* Update readme
* Make embed_tokens and lm_head compatible for TP
* Fix multi-node zmq_comm
* Bump version to 0.1.0
1 parent e38eae8 commit 577104b

33 files changed (+1528, -539 lines)

CMakeLists.txt (1 addition, 1 deletion)

@@ -21,7 +21,7 @@ set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0")
 
 
 # Supported/expected torch versions for CUDA.
-set(TORCH_SUPPORTED_VERSION_CUDA "2.7.0")
+set(TORCH_SUPPORTED_VERSION_CUDA "2.5.1")
 
 #
 # Try to find python package with an executable that exactly matches

README.md (6 additions, 6 deletions)

@@ -15,9 +15,10 @@ Global Balanced Pipeline Parallelism System for Distributed LLM Serving with Tok
 <img src=doc/pic/overview.svg width=500>
 </p>
 
-Integreted with features like **continuous batching**, **paged attention**, **chunked prefill**, **prefix caching**, **token throttling** and **pipeline parallelism**, gLLM provides basic functionality (offline/online inference and interactive chat) to support large language model inference. gLLM provides **equivalent or superior** offline/online inference speed with mainstream inference engine and **minimal** (~4k loc) code base. You can also see gLLM as a LLM inference playground for doing experiment or academic research.
+Integreted with features like **continuous batching**, **paged attention**, **chunked prefill**, **prefix caching**, **token throttling**, **pipeline parallelism** and **tensor parallelism**, gLLM provides basic functionality (**offline/online inference and interactive chat**) to deploy distributed LLMs (**supported in huggingface**) inference. gLLM provides **equivalent or superior** offline/online inference speed with mainstream inference engine and **minimal** (~6k loc) code base. You can also see gLLM as a LLM inference playground for doing experiment or academic research.
 
 *Latest News* :fire:
+- [2025/06/14]: Tensor parallelism is now integrated, allowing joint deploying with pipeline parallelism :sunglasses:
 - [2025/05/05]: MoE architecture is supported. Try Qwen2/3 MoE models :star_struck:
 - [2025/04/29]: Qwen3 day 1 support. Come and try Qwen3 :tada:
 - [2025/04/27]: gLLM is open sourced :earth_asia:

@@ -43,7 +44,7 @@
 
 ## Install gLLM
 ```
-pip install torch==2.7.0
+pip install torch==2.5.1
 pip install -v -e .
 ```
 

@@ -73,7 +74,7 @@ python benchmarks/benchmark_throughput.py --model $MODEL \
 ```
 # To see the description of args, run 'python -m gllm.entrypoints.api_server -h'
 python -m gllm.entrypoints.api_server --port $PORT --model-path $MODEL_PATH \
-    --enable-prefix-caching --pp $PP
+    --enable-prefix-caching --pp $PP --tp $TP
 ```
 
 ### Launch OpenAI-Compatible Server (Multi-node)

@@ -142,13 +143,12 @@ python evaluations/evaluate_MMLU_pro.py --model $MODEL --port $PORT
 ## Supported Models
 
 - Qwen Series: Qwen3, Qwen2.5, Qwen2
-- Llama Series: Llama3.1, Llama3, Llama2 and deepseek-coder
+- Llama Series: Llama3.2, Llama3.1, Llama3, Llama2 and deepseek-coder
 - Mixtral Series: Mixtral-8x7B, Mixtral-8x22B
-- ChatGLM Series: Chatglm3 and glm4
+- ChatGLM Series: Glm4 and Chatglm3
 
 ## Roadmap
 
-- [ ] Support TP
 - [ ] Support more models
 

examples/chat_client.py (4 additions, 4 deletions)

@@ -22,14 +22,14 @@
 messages = []
 
 print("\nWelcome to the chatbot!\n"
-      "Type '\exit' to exit the chatbot.\n"
-      "Type '\clear' to clear the chatbot's history.\n")
+      "Type '\\exit' to exit the chatbot.\n"
+      "Type '\\clear' to clear the chatbot's history.\n")
 
 while True:
     prompt = input('>>> ')
-    if prompt == '\exit':
+    if prompt == '\\exit':
         break
-    elif prompt == '\clear':
+    elif prompt == '\\clear':
         messages = []
     messages.append({'role': 'user', 'content': prompt})
     chat_completion = client.chat.completions.create(
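The doubled backslashes in this hunk are more than style: `\e` and `\c` are not recognized escape sequences, so Python keeps the backslash literally but recent interpreters warn about it (and future versions may reject it). A quick standalone sketch of why both spellings compare equal today:

```python
# '\e' is an unrecognized escape: Python 3 leaves the backslash in place,
# but warns on recent versions. The commit's doubled form produces the
# identical 5-character string without the warning.
implicit = '\exit'     # relies on the unrecognized-escape fallback
explicit = '\\exit'    # the spelling the commit switches to
assert implicit == explicit
assert explicit == chr(92) + 'exit'   # backslash followed by "exit"
assert len(explicit) == 5
```

So the change is behavior-preserving; it only removes the deprecation hazard.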

gllm/async_llm_engine.py (10 additions, 8 deletions)

@@ -13,7 +13,7 @@
 from gllm.worker import Worker, run_worker
 from gllm.input_data import InputData
 from gllm.sequence import Sequence
-from gllm.scheduler import IPCPackage
+from gllm.frontend_scheduler import IPCPackage
 from gllm.zmq_comm import zmqComm
 
 

@@ -142,7 +142,7 @@ def __init__(self, *args, **kwargs):
             self.act_worker_ranks = [int(i) for i in self.worker_ranks.split(',')]
             assert len(self.act_worker_ranks) != 0
         else:
-            self.act_worker_ranks = list(range(self.pp_size))
+            self.act_worker_ranks = list(range(self.pp_size*self.tp_size))
         self.num_workers = len(self.act_worker_ranks)
 
         self.ctx = mp.get_context('spawn')

@@ -156,7 +156,7 @@ def __init__(self, *args, **kwargs):
         self.token_path = f'ipc:///tmp/{ipc_path_prefix}_gllm_token'
 
         self.comm = zmqComm(self.host, self.zmq_port_base, self.launch_mode, self.master_addr,
-                            0, 0, self.schedule_path, self.output_path, self.token_path)
+                            self.schedule_path, self.output_path, self.token_path, frontend=True)
         self.comm.init()
 
         logger.info(f'Launching worker {self.act_worker_ranks} ...')

@@ -166,8 +166,10 @@ def __init__(self, *args, **kwargs):
             logger.warning(f'Multi-node support is an experimental feature')
 
         self.process_list = []
-        for local_rank, pp_rank in enumerate(self.act_worker_ranks):
-            self.start_worker(local_rank, pp_rank)
+        for local_rank, rank in enumerate(self.act_worker_ranks):
+            pp_rank = rank // self.tp_size
+            tp_rank = rank % self.tp_size
+            self.start_worker(local_rank, pp_rank, tp_rank)
 
         if kwargs['load_format'] == 'auto':
             self.load_progress()

@@ -256,21 +258,21 @@ async def run_schedule_engine(self):
             self.send_ipc_package()
             await asyncio.sleep(0)
 
-    def start_worker(self, local_rank, pp_rank):
+    def start_worker(self, local_rank, pp_rank, tp_rank):
         worker_cls = Worker if not self.use_async_worker else AsyncWorker
         comm = zmqComm(self.host,
                        self.zmq_port_base,
                        self.launch_mode,
                        self.master_addr,
-                       pp_rank,
-                       self.pp_size,
                        self.schedule_path,
                        self.output_path,
                        self.token_path)
         worker = worker_cls(self.model_runner,
                             local_rank,
                             pp_rank,
+                            tp_rank,
                             self.pp_size,
+                            self.tp_size,
                             self.master_addr,
                             self.master_port,
                             comm,
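The launch loop above commits to a pp-major rank layout: all tensor-parallel ranks of a pipeline stage are contiguous. A standalone sketch of the mapping (function names here are illustrative, not part of gLLM):

```python
def split_rank(rank: int, tp_size: int):
    """Decompose a global worker rank the way the launch loop does:
    pp_rank = rank // tp_size, tp_rank = rank % tp_size."""
    return rank // tp_size, rank % tp_size

def merge_rank(pp_rank: int, tp_rank: int, tp_size: int) -> int:
    """Inverse mapping, matching _RANK = pp_rank * tp_size + tp_rank
    in the dist_utils changes below in this commit."""
    return pp_rank * tp_size + tp_rank

# pp=2, tp=2 gives four workers laid out pp-major:
layout = {r: split_rank(r, 2) for r in range(4)}
# {0: (0, 0), 1: (0, 1), 2: (1, 0), 3: (1, 1)}
```

The two functions are exact inverses for any rank in `range(pp_size * tp_size)`, which is what lets the engine keep a single flat `act_worker_ranks` list.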

gllm/async_worker.py (7 additions, 1 deletion)

@@ -18,6 +18,10 @@ def __init__(self, *args, **kwargs):
    async def run_driver(self):
        return super().run_driver()
 
+    @async_wrapper
+    async def run_first_tp(self):
+        return super().run_first_tp()
+
    @async_wrapper
    async def run_other(self):
        return super().run_other()

@@ -36,8 +40,10 @@ async def launch_async_tasks(worker: AsyncWorker):
     worker.init()
 
     ats = AsyncTasks()
-    if worker.pp_rank == 0:
+    if worker.rank == 0:
         ats.add_task(worker.run_driver)
+    elif worker.pp_rank == 0:
+        ats.add_task(worker.run_first_tp)
     else:
         ats.add_task(worker.run_other)
     await ats.wait()
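With TP added, `launch_async_tasks` now routes each worker to one of three loops: the global rank-0 driver, the remaining tensor-parallel ranks of the first pipeline stage, and everything downstream. The branching can be summarized as a pure function (a sketch; the string labels stand in for the coroutines):

```python
def select_task(rank: int, pp_rank: int) -> str:
    """Mirror the branch order in launch_async_tasks above."""
    if rank == 0:
        return 'run_driver'      # global rank 0 drives scheduling
    elif pp_rank == 0:
        return 'run_first_tp'    # other TP ranks of the first stage
    return 'run_other'           # every rank of later stages

# pp=2, tp=2, pp-major layout (pp_rank = rank // tp_size):
routing = {r: select_task(r, r // 2) for r in range(4)}
# {0: 'run_driver', 1: 'run_first_tp', 2: 'run_other', 3: 'run_other'}
```

Note the branch order matters: rank 0 also has `pp_rank == 0`, so the driver check must come first.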

gllm/dist_utils.py (135 additions, 11 deletions)

@@ -2,24 +2,25 @@
 import torch
 
 from logger import logger
+from collections.abc import Sequence
 
 def send_pp_data(output, dst):
     if type(output) == tuple:
         assert len(output) == 2
-        dist.isend(output[0],dst)
-        dist.isend(output[1],dst)
+        dist.isend(output[0], dst)
+        dist.isend(output[1], dst)
     else:
-        dist.isend(output,dst)
+        dist.isend(output, dst)
 
 def recv_pp_data(src, shape, has_residual):
     hidden_states = torch.zeros(torch.Size(shape))
     if has_residual:
         residual = hidden_states.clone().detach()
-        hidden_states_future = dist.irecv(hidden_states,src)
-        residual_future = dist.irecv(residual,src)
+        hidden_states_future = dist.irecv(hidden_states, src)
+        residual_future = dist.irecv(residual, src)
         return hidden_states_future, residual_future, hidden_states, residual
     else:
-        hidden_states_future = dist.irecv(hidden_states,src)
+        hidden_states_future = dist.irecv(hidden_states, src)
         return hidden_states_future, hidden_states
 
 def send_obj_list(obj_list, dst):
@@ -28,36 +29,88 @@ def send_obj_list(obj_list, dst):
 def recv_obj_list(obj_list, src):
     dist.recv_object_list(obj_list, src=src)
 
+_RANK=0
 _PP_RANK=0
+_TP_RANK=0
 _LOCAL_RANK=0
 _PP_SIZE=1
+_TP_SIZE=1
+_WORLD_SIZE=1
 _ASSIGNED_LAYERS=None
+_TP_GROUP=None
+
+def get_rank():
+    return _RANK
+
+def get_world_size():
+    return _WORLD_SIZE
 
 def get_pp_rank():
     return _PP_RANK
 
+def get_tp_rank():
+    return _TP_RANK
+
 def get_local_rank():
     return _LOCAL_RANK
 
-def is_pp_last_rank():
+def get_output_rank():
+    return (get_pp_size() - 1) * get_tp_size()
+
+def is_output_rank():
+    return is_last_pp_rank() and is_first_tp_rank()
+
+def is_first_tp_rank():
+    return get_tp_rank() == 0
+
+def is_last_pp_rank():
     return get_pp_rank() == get_pp_size() - 1
 
+def get_next_pp_rank():
+    return get_rank() + get_tp_size()
+
+def get_last_pp_rank():
+    return get_rank() - get_tp_size()
+
 def get_pp_size():
     return _PP_SIZE
 
+def get_tp_size():
+    return _TP_SIZE
+
 def get_assigned_layers():
     return _ASSIGNED_LAYERS
 
-def init_dist(pp_size, local_rank, pp_rank, master_addr, master_port, assigned_layers):
-    global _PP_RANK, _PP_SIZE, _ASSIGNED_LAYERS, _LOCAL_RANK
+def get_tp_group():
+    return _TP_GROUP
+
+def init_tp_group():
+    global _TP_GROUP
+    tp_groups = [list(range(_pp_rank*get_tp_size(), (_pp_rank+1)*get_tp_size())) for _pp_rank in range(get_pp_size())]
+    for tp_ranks in tp_groups:
+        tp_group = dist.new_group(tp_ranks)
+        if _RANK in tp_ranks:
+            _TP_GROUP = tp_group
+
+def init_dist(pp_size, tp_size, local_rank, pp_rank, tp_rank, master_addr, master_port, assigned_layers):
+    global _RANK, _PP_RANK, _TP_RANK, _PP_SIZE, _TP_SIZE, _WORLD_SIZE, _ASSIGNED_LAYERS, _LOCAL_RANK, _TP_GROUP, _PP_GROUP
+    _RANK = pp_rank * tp_size + tp_rank
     _PP_RANK = pp_rank
+    _TP_RANK = tp_rank
     _LOCAL_RANK = local_rank
     _PP_SIZE = pp_size
+    _TP_SIZE = tp_size
+    _WORLD_SIZE = pp_size * tp_size
     _ASSIGNED_LAYERS = assigned_layers
+
+    self_tp_ranks = list(range(pp_rank*tp_size, (pp_rank+1)*tp_size))
+
     init_method = f'tcp://{master_addr}:{master_port}'
     backend = 'nccl'
-    logger.info(f'NCCL: Init_method {init_method}, Backend {backend}, Word_size {pp_size}')
-    dist.init_process_group(init_method=init_method, backend=backend, world_size=pp_size, rank=pp_rank)
+    logger.info(f'NCCL: Init_method {init_method}, Backend {backend}, Rank {_RANK}, TP Groups {self_tp_ranks}, Word_size {_WORLD_SIZE}')
+    dist.init_process_group(init_method=init_method, backend=backend, world_size=_WORLD_SIZE, rank=_RANK)
+
+    init_tp_group()
 
 def get_pp_layers(num_layers):
     if _ASSIGNED_LAYERS is None:
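The new `init_tp_group` builds one communicator per pipeline stage. The group enumeration itself is plain arithmetic and can be checked without torch; this sketch replaces `dist.new_group` with plain rank lists (note that in the real code every rank must still execute `new_group` for every group, because process-group creation is a collective operation):

```python
def enumerate_tp_groups(pp_size: int, tp_size: int) -> list:
    """Ranks of each stage's TP group under the pp-major layout:
    stage p owns ranks [p*tp_size, (p+1)*tp_size)."""
    return [list(range(p * tp_size, (p + 1) * tp_size))
            for p in range(pp_size)]

enumerate_tp_groups(2, 4)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```

Each rank then keeps only the group containing its own `_RANK`, which is exactly the `if _RANK in tp_ranks` branch in the diff.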
@@ -93,3 +146,74 @@ def resolve_pp_layer(layer_name, idx, start_layer_idx):
         return '.'.join(layer_name_list)
     else:
         return layer_name
+
+def tensor_model_parallel_all_gather(input_: torch.Tensor, dim=-1) -> torch.Tensor:
+    """All-gather the input tensor across model parallel group."""
+    if dim < 0:
+        # Convert negative dim to positive.
+        dim += input_.dim()
+    input_size = input_.size()
+    # NOTE: we have to use concat-style all-gather here,
+    # stack-style all-gather has compatibility issues with
+    # torch.compile . see https://github.com/pytorch/pytorch/issues/138795
+    output_size = (input_size[0] * get_tp_size(), ) + input_size[1:]
+    # Allocate output tensor.
+    output_tensor = torch.empty(output_size,
+                                dtype=input_.dtype,
+                                device=input_.device)
+    # All-gather.
+    dist.all_gather_into_tensor(output_tensor,
+                                input_,
+                                group=get_tp_group())
+    # Reshape
+    output_tensor = output_tensor.reshape((get_tp_size(), ) + input_size)
+    output_tensor = output_tensor.movedim(0, dim)
+    output_tensor = output_tensor.reshape(input_size[:dim] +
+                                          (get_tp_size() *
+                                           input_size[dim], ) +
+                                          input_size[dim + 1:])
+    return output_tensor
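The reshape/movedim dance after `all_gather_into_tensor` exists because the concat-style gather stacks the shards along dim 0, while the caller wants them interleaved along `dim`. For the common last-dim case the net effect can be emulated with plain lists (no torch; the function name is illustrative):

```python
def emulate_all_gather_last_dim(shards):
    """shards[r] is rank r's 2-D shard as a list of rows. The concat-style
    gather yields [shard0; shard1; ...] along dim 0; the reshape + movedim
    then glues row i of every shard together, i.e. gathers along dim -1."""
    n_rows = len(shards[0])
    return [sum((shard[i] for shard in shards), []) for i in range(n_rows)]

# tp_size=2, each rank holds a (2, 2) shard of a (2, 4) matrix:
shards = [[[1, 2], [5, 6]],   # rank 0: left half columns
          [[3, 4], [7, 8]]]   # rank 1: right half columns
emulate_all_gather_last_dim(shards)  # [[1, 2, 3, 4], [5, 6, 7, 8]]
```

This is why column-parallel layers can shard the hidden dimension and still hand a full-width activation to the next op.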
+
+def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
+    """All-reduce the input tensor across model parallel group."""
+    dist.all_reduce(input_, group=get_tp_group())
+    return input_
+
+def ensure_divisibility(numerator, denominator):
+    """Ensure that numerator is divisible by the denominator."""
+    assert numerator % denominator == 0, "{} is not divisible by {}".format(
+        numerator, denominator)
+
+
+def divide(numerator, denominator):
+    """Ensure that numerator is divisible by the denominator and return
+    the division value."""
+    ensure_divisibility(numerator, denominator)
+    return numerator // denominator
+
+def split_tensor_along_last_dim(
+    tensor: torch.Tensor,
+    num_partitions: int,
+    contiguous_split_chunks: bool = False,
+) -> Sequence[torch.Tensor]:
+    """ Split a tensor along its last dimension.
+
+    Arguments:
+        tensor: input tensor.
+        num_partitions: number of partitions to split the tensor
+        contiguous_split_chunks: If True, make each chunk contiguous
+                                 in memory.
+
+    Returns:
+        A list of Tensors
+    """
+    # Get the size and dimension.
+    last_dim = tensor.dim() - 1
+    last_dim_size = divide(tensor.size()[last_dim], num_partitions)
+    # Split.
+    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
+    # NOTE: torch.split does not create contiguous tensors by default.
+    if contiguous_split_chunks:
+        return tuple(chunk.contiguous() for chunk in tensor_list)
+
+    return tensor_list
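`split_tensor_along_last_dim` is the inverse of the gather above; its partitioning arithmetic (exact division into equal chunks) is easy to sanity-check on a 1-D list. A pure-Python stand-in, not the torch code:

```python
def divide(numerator: int, denominator: int) -> int:
    """Same contract as the dist_utils helper: exact division or AssertionError."""
    assert numerator % denominator == 0, \
        "{} is not divisible by {}".format(numerator, denominator)
    return numerator // denominator

def split_along_last_dim(values: list, num_partitions: int) -> list:
    """Cut a flat list into num_partitions equal contiguous chunks,
    mirroring torch.split with a fixed chunk size."""
    chunk = divide(len(values), num_partitions)
    return [values[i * chunk:(i + 1) * chunk] for i in range(num_partitions)]

split_along_last_dim(list(range(8)), 4)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```

The hard divisibility assert is deliberate: a hidden size that does not divide evenly by `tp_size` should fail loudly at load time rather than silently mis-shard weights.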

gllm/entrypoints/api_server.py (2 additions, 0 deletions)

@@ -100,6 +100,7 @@ async def run_server(args):
     parser.add_argument('--use-naive-schedule', help='Use scheduling policy in Sarathi-Serve', action='store_true')
     parser.add_argument('--enable-prefix-caching', help='Enable KV cache reuse across requests', action='store_true')
     parser.add_argument('--pp', type=int, help='Number of pipeline stages', default=1)
+    parser.add_argument('--tp', type=int, help='Number of tensor parallel degrees', default=1)
     parser.add_argument('--load-format', type=str, choices=['auto','dummy'], help='auto: actually load model weights; dummy: initialize the model with random values', default='auto')
     parser.add_argument('--assigned-layers', type=str, help='If the model have 64 layers, we can set it to 16,16,16,16 or 16,16,17,15', default=None)
     parser.add_argument('--use-async-worker', help='Experimental feature for worker implemented by async', action='store_true')

@@ -125,6 +126,7 @@ async def run_server(args):
                         kvthresh=args.kvthresh,
                         enable_prefix_caching=args.enable_prefix_caching,
                         pp_size=args.pp,
+                        tp_size=args.tp,
                         assigned_layers=args.assigned_layers,
                         use_naive_schedule=args.use_naive_schedule,
                         use_async_worker=args.use_async_worker)
Lines changed: 1 addition & 1 deletion

@@ -17,7 +17,7 @@ def __init__(self, schedule_lists: List[Sequence]):
         self.act_schedule_ids = []
         self.next_tokens = []
 
-class Scheduler:
+class FrontendScheduler:
     def __init__(self, maxd: int, maxp: int, kvthresh: float,
                  page_size: int) -> None:
         self.prompt_lists: List[Sequence] = [] # seqs to prefill
