Commit 0263f7a

Author: niushengxiao
Commit message: feat: split_n
1 parent 78c1ffd · commit 0263f7a

3 files changed: +102 additions, -16 deletions


lightllm/common/basemodel/basemodel.py

Lines changed: 3 additions & 2 deletions
@@ -78,7 +78,7 @@ def __init__(self, kvargs):
         self._init_mem_manager()
         self._init_weights()

-        self.stream_num = 2
+        self.stream_num = 2
         self.graph = [None] * self.stream_num
         self.stream = [None] * self.stream_num
         for i in range(self.stream_num):
@@ -212,7 +212,8 @@ def _init_cudagraph(self):
             self.graph[i] = (
                 None if self.disable_cudagraph else CudaGraph(self.stream[i], self.graph_max_batch_size, self.graph_max_len_in_batch)
             )
-            self.graph[i].warmup(self, i)
+            if self.graph[i] is not None:
+                self.graph[i].warmup(self, i)

     def _init_custom(self):
         pass
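
A minimal standalone sketch (not the repository's code; CudaGraph is never instantiated here) of the per-stream setup and the None-guard this hunk adds: when cudagraph is disabled every slot stays None, so warmup has to be skipped instead of being called unconditionally.

import torch

stream_num = 2
streams = [torch.cuda.Stream() for _ in range(stream_num)]  # one CUDA stream per slot
graphs = [None] * stream_num  # would hold one CudaGraph per slot when enabled

for i in range(stream_num):
    # With cudagraph disabled every slot stays None; the pre-fix code called
    # graphs[i].warmup(...) unconditionally and would raise AttributeError here.
    if graphs[i] is not None:
        graphs[i].warmup(i)  # hypothetical warmup; the diff passes (model, i)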

lightllm/distributed/communication_op.py

Lines changed: 1 addition & 6 deletions
@@ -78,16 +78,11 @@ def set_custom_reduce(self):

         # Create a new NCCL group to keep the original all_reduce from hanging with cudagraph
         if self.device_group is None:
-            # self.device_group_list = []
-            # for _ in range(2):
-            #     device_group = dist.new_group(ranks, backend="nccl")
-            #     self.device_group_list.append(device_group)
             self.device_group = dist.new_group(ranks, backend="nccl")

         if ENABLE_VLLM_REDUCE and HAS_VLLM:
-            cpu_group = [dist.new_group(ranks, backend="gloo")] * self.reduce_num
             for i in range(self.reduce_num):
-                self.vllm_reduce[i] = CustomAllreduce(cpu_group[i], torch.cuda.current_device())
+                self.vllm_reduce[i] = CustomAllreduce([dist.new_group(ranks, backend="gloo")], torch.cuda.current_device())
             logger.info("Enable VLLM ALLReduce.")

         def _all_reduce_closure(input_, op=ReduceOp.SUM, group=self.device_group, async_op=False, all_reduce_id=0):
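
The idea behind keeping self.reduce_num separate reducers is that collectives issued from different CUDA streams should not share one communicator. A minimal sketch of that pattern with plain torch.distributed (make_reduce_groups and all_reduce_by_id are illustrative names, not from the repository):

import torch
import torch.distributed as dist

def make_reduce_groups(ranks, reduce_num):
    # One NCCL group per concurrent all-reduce slot (one per stream)
    return [dist.new_group(ranks, backend="nccl") for _ in range(reduce_num)]

def all_reduce_by_id(tensor, groups, all_reduce_id=0):
    # Route the collective to the communicator that matches the issuing stream,
    # so two streams never queue collectives on the same NCCL group
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=groups[all_reduce_id])
    return tensor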

lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py

Lines changed: 98 additions & 8 deletions
@@ -23,7 +23,9 @@ def split_kwargs(
     b_seq_len: torch.Tensor,
     b_ready_cache_len: torch.Tensor = None,
     multimodal_params=None,
-    is_prefill=True):
+    is_prefill=True,
+    split_n=2):
+    kwargs = [None] * split_n
     half_batch = batch_size // 2
     b_req_idx1 = b_req_idx[:half_batch]
     b_req_idx2 = b_req_idx[half_batch:]
@@ -77,6 +79,92 @@ def split_kwargs(
     kwargs2["multimodal_params"] = multimodal_params
     return kwargs1, kwargs2

+def split_kwargs_n(
+    batch_size,
+    total_token_num,
+    max_len_in_batch,
+    input_ids: torch.Tensor,
+    mem_indexes: torch.Tensor,
+    b_req_idx: torch.Tensor,
+    b_start_loc: torch.Tensor,
+    b_seq_len: torch.Tensor,
+    b_ready_cache_len: torch.Tensor = None,
+    multimodal_params=None,
+    is_prefill=True,
+    split_n=2):
+
+    kwargs = [None] * split_n
+
+    # Compute the batch size of each split
+    batch_per_split = [batch_size // split_n] * split_n
+    # Handle the case where the batch does not divide evenly
+    for i in range(batch_size % split_n):
+        batch_per_split[i] += 1
+
+    # Prepare the split boundaries
+    batch_indices = [0]
+    for size in batch_per_split:
+        batch_indices.append(batch_indices[-1] + size)
+
+    # Number of tokens consumed so far
+    cumulative_tokens = 0
+
+    # Build the kwargs for each split
+    for i in range(split_n):
+        start_idx = batch_indices[i]
+        end_idx = batch_indices[i + 1]
+
+        # Slice the batch-level tensors
+        split_b_req_idx = b_req_idx[start_idx:end_idx]
+        split_b_seq_len = b_seq_len[start_idx:end_idx]
+
+        # Token count of this split
+        split_tokens = split_b_seq_len.sum().item()
+
+        if is_prefill:
+            # In the prefill stage, split by tokens
+            token_start = cumulative_tokens
+            token_end = token_start + split_tokens
+            split_input_ids = input_ids[token_start:token_end]
+            split_mem_indexes = mem_indexes[token_start:token_end]
+        else:
+            # In the decode stage, split by batch
+            split_input_ids = input_ids[start_idx:end_idx]
+            split_mem_indexes = mem_indexes[start_idx:end_idx]
+
+        # Derive the remaining parameters of this split
+        split_max_len = split_b_seq_len.max().item() if len(split_b_seq_len) > 0 else 0
+        split_b_start_loc = split_b_seq_len.cumsum(dim=0) - split_b_seq_len
+
+        # Handle the ready-cache lengths
+        split_b_ready_cache_len = None
+        if b_ready_cache_len is not None:
+            split_b_ready_cache_len = b_ready_cache_len[start_idx:end_idx]
+
+        # Assemble the kwargs dict
+        kwargs[i] = {
+            "batch_size": len(split_b_req_idx),
+            "total_token_num": split_tokens,
+            "max_len_in_batch": split_max_len,
+            "input_ids": split_input_ids,
+            "mem_indexes": split_mem_indexes,
+            "b_req_idx": split_b_req_idx,
+            "b_start_loc": split_b_start_loc,
+            "b_seq_len": split_b_seq_len,
+            "b_ready_cache_len": split_b_ready_cache_len,
+            "is_prefill": is_prefill,
+            "all_reduce_id": i,
+        }
+
+        # Attach multimodal params when present
+        if multimodal_params is not None:
+            kwargs[i]["multimodal_params"] = multimodal_params
+
+        # Update the running token count
+        cumulative_tokens += split_tokens
+
+    return kwargs
+
 class ContinuesBatchBackend(ModeBackend):
     def __init__(self) -> None:
         super().__init__()
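
A hedged usage sketch of the new split_kwargs_n, run outside the repository with made-up tensor values, showing how a decode batch of 5 requests is divided when split_n=2 (the first split absorbs the remainder):

import torch

# Toy decode batch: 5 requests, one new token each; the values are illustrative only
kwargs = dict(
    batch_size=5,
    total_token_num=5,
    max_len_in_batch=12,
    input_ids=torch.arange(5),
    mem_indexes=torch.arange(5),
    b_req_idx=torch.arange(5),
    b_start_loc=torch.tensor([0, 10, 22, 30, 35]),
    b_seq_len=torch.tensor([10, 12, 8, 5, 5]),
    is_prefill=False,
)

kwargs_list = split_kwargs_n(**kwargs, split_n=2)
# kwargs_list[0]: requests 0-2 (batch_size 3, all_reduce_id 0)
# kwargs_list[1]: requests 3-4 (batch_size 2, all_reduce_id 1)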
@@ -106,14 +194,16 @@ def prefill(self, reqs: List[Tuple]):
     def decode(self):
         kwargs, run_reqs = prepare_decode_inputs(g_infer_context.infer_req_ids)
         # logits = self.model.forward(**kwargs)
-        if kwargs["batch_size"] > 1:
-            kwargs1, kwargs2 = split_kwargs(**kwargs)
-            with torch.cuda.stream(self.model.stream[0]):
-                logits1 = self.model.forward(**kwargs1)
-            with torch.cuda.stream(self.model.stream[1]):
-                logits2 = self.model.forward(**kwargs2)
+        split_n = self.model.stream_num
+        if kwargs["batch_size"] > split_n - 1:
+            kwargs_list = split_kwargs_n(**kwargs, split_n=split_n)
+            logits = [None] * split_n
+            for i in range(split_n):
+                with torch.cuda.stream(self.model.stream[i]):
+                    logits[i] = self.model.forward(**kwargs_list[i])
+
             torch.cuda.synchronize()
-            logits = torch.cat([logits1, logits2], dim=0)
+            logits = torch.cat(logits, dim=0)
         else:
             logits = self.model.forward(**kwargs)
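
For context, a self-contained sketch (not the repository's code; the matmul stands in for self.model.forward) of the multi-stream pattern the new decode path relies on: launch each split on its own CUDA stream, synchronize, then concatenate the per-split logits back into batch order.

import torch

streams = [torch.cuda.Stream() for _ in range(2)]
inputs = [torch.randn(3, 8, device="cuda"), torch.randn(2, 8, device="cuda")]
weight = torch.randn(8, 8, device="cuda")
outs = [None] * 2

for i, s in enumerate(streams):
    with torch.cuda.stream(s):
        outs[i] = inputs[i] @ weight  # stand-in for self.model.forward(**kwargs_list[i])

torch.cuda.synchronize()              # wait for every stream before reading the results
logits = torch.cat(outs, dim=0)       # merge per-split outputs back into batch order
print(logits.shape)                   # torch.Size([5, 8])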
