
Commit d07ab83

Merge branch 'main' into reduce_fix

2 parents 5f89058 + b03b60e

File tree: 17 files changed, +146 -31 lines

lightllm/common/basemodel/basemodel.py
Lines changed: 8 additions & 1 deletion

@@ -166,7 +166,14 @@ def _check_mem_size(self):
         return

     def _init_req_manager(self):
-        self.req_manager = ReqManager(self.max_req_num, self.max_seq_length, self.mem_manager)
+        create_max_seq_len = 0
+
+        if self.batch_max_tokens is not None:
+            create_max_seq_len = max(create_max_seq_len, self.batch_max_tokens)
+        if self.max_seq_length is not None:
+            create_max_seq_len = max(create_max_seq_len, self.max_seq_length)
+
+        self.req_manager = ReqManager(self.max_req_num, create_max_seq_len, self.mem_manager)
         return

     def _init_infer_layer(self):

lightllm/server/api_cli.py
Lines changed: 6 additions & 0 deletions

@@ -13,6 +13,12 @@ def make_argument_parser() -> argparse.ArgumentParser:
     )
     parser.add_argument("--host", type=str, default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8000)
+    parser.add_argument(
+        "--zmq_mode",
+        type=str,
+        default="ipc:///tmp/",
+        help="use tcp or ipc mode for zmq; must be one of ['tcp://', 'ipc:///tmp/']",
+    )

     parser.add_argument(
         "--pd_master_ip",
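
A hedged example of passing the new flag at launch (the entry point and model path are illustrative; ipc:///tmp/ is already the default):

python -m lightllm.server.api_server --model_dir /path/to/model --zmq_mode ipc:///tmp/
python -m lightllm.server.api_server --model_dir /path/to/model --zmq_mode tcp://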

lightllm/server/api_models.py
Lines changed: 10 additions & 2 deletions

@@ -1,6 +1,6 @@
 import time

-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, field_validator
 from typing import Dict, List, Optional, Union, Literal
 import uuid

@@ -9,7 +9,7 @@ class ChatCompletionRequest(BaseModel):
     # The openai api native parameters
     model: str
     messages: List[Dict[str, str]]
-    function_call: Optional[str] = 'none'
+    function_call: Optional[str] = "none"
     temperature: Optional[float] = 1
     top_p: Optional[float] = 1.0
     n: Optional[int] = 1
@@ -52,6 +52,10 @@ class ChatCompletionResponse(BaseModel):
     choices: List[ChatCompletionResponseChoice]
     usage: UsageInfo

+    @field_validator("id", mode="before")
+    def ensure_id_is_str(cls, v):
+        return str(v)
+

 class DeltaMessage(BaseModel):
     role: Optional[str] = None
@@ -70,3 +74,7 @@ class ChatCompletionStreamResponse(BaseModel):
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionStreamResponseChoice]
+
+    @field_validator("id", mode="before")
+    def ensure_id_is_str(cls, v):
+        return str(v)
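
A quick standalone check of what the validator buys (pydantic v2 semantics; the Demo model is illustrative): with mode="before", the raw value is coerced before the str field is validated, so an integer response id no longer fails validation.

from pydantic import BaseModel, field_validator

class Demo(BaseModel):
    id: str

    @field_validator("id", mode="before")
    def ensure_id_is_str(cls, v):
        return str(v)

print(Demo(id=123).id)  # "123" instead of a ValidationError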

lightllm/server/api_start.py
Lines changed: 21 additions & 1 deletion

@@ -4,7 +4,7 @@
 from lightllm.server import TokenLoad
 from .api_lightllm import lightllm_generate, lightllm_generate_stream
 from .api_tgi import tgi_generate_impl, tgi_generate_stream_impl
-from lightllm.utils.net_utils import alloc_can_use_network_port
+from lightllm.utils.net_utils import alloc_can_use_network_port, PortLocker
 from lightllm.utils.start_utils import start_submodule_processes
 from .metrics.manager import start_metric_manager
 from .embed_cache.manager import start_cache_manager
@@ -27,6 +27,15 @@ def normal_or_p_d_start(g_objs):
     if args.run_mode not in ["normal", "prefill", "decode"]:
         return

+    assert args.zmq_mode in ["tcp://", "ipc:///tmp/"]
+
+    # make sure multiple instances on one machine do not conflict
+    if args.zmq_mode == "ipc:///tmp/":
+        zmq_mode = f"{args.zmq_mode}_{str(args.nccl_port)}_"
+        args.zmq_mode = None  # args attributes cannot be overwritten directly; set to None first, then assign
+        args.zmq_mode = zmq_mode
+    logger.info(f"zmq mode head: {args.zmq_mode}")
+
     if args.use_tgi_api:
         g_objs.g_generate_func = tgi_generate_impl
         g_objs.g_generate_stream_func = tgi_generate_stream_impl
@@ -117,9 +126,18 @@ def normal_or_p_d_start(g_objs):
     assert args.data_type in ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"]

     already_uesd_ports = args.visual_nccl_ports + [args.nccl_port, args.port]
+    if args.run_mode == "decode":
+        already_uesd_ports = args.visual_nccl_ports + [args.nccl_port, args.port, args.pd_decode_rpyc_port]
+
+    # lock the ports in advance, so that when several instances start on one machine,
+    # port conflicts are caught here instead of surfacing only once the model starts
+    ports_locker = PortLocker(already_uesd_ports)
+    ports_locker.lock_port()
+
     can_use_ports = alloc_can_use_network_port(
         num=6 + args.tp + args.tp + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports
     )
+    logger.info(f"alloced ports: {can_use_ports}")
     router_port, detokenization_port, httpserver_port, visual_port, cache_port, metric_port = can_use_ports[0:6]
     model_rpc_ports = can_use_ports[6 : 6 + args.tp]
     can_use_ports = can_use_ports[6 + args.tp :]
@@ -144,6 +162,8 @@ def normal_or_p_d_start(g_objs):

     logger.info(f"all start args:{args}")

+    ports_locker.release_port()
+
     if args.enable_multimodal:
         start_submodule_processes(
             start_funcs=[
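
PortLocker's implementation lives in lightllm.utils.net_utils and is not part of this diff; a minimal sketch of the idea, assuming it simply binds and holds the ports until released, might be:

import socket

class PortLockerSketch:
    # Assumed behavior only: bind each port eagerly so a conflict raises
    # OSError at startup instead of deep inside model initialization.
    def __init__(self, ports):
        self.ports = ports
        self.socks = []

    def lock_port(self):
        for port in self.ports:
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            s.bind(("0.0.0.0", port))  # fails fast if another instance holds it
            self.socks.append(s)

    def release_port(self):
        for s in self.socks:
            s.close()
        self.socks = []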

lightllm/server/detokenization/manager.py
Lines changed: 2 additions & 2 deletions

@@ -33,10 +33,10 @@ def __init__(
         self.args = args
         context = zmq.asyncio.Context(2)
         self.recv_from_router = context.socket(zmq.PULL)
-        self.recv_from_router.bind(f"tcp://127.0.0.1:{detokenization_port}")
+        self.recv_from_router.bind(f"{args.zmq_mode}127.0.0.1:{detokenization_port}")

         self.send_to_httpserver = context.socket(zmq.PUSH)
-        self.send_to_httpserver.connect(f"tcp://127.0.0.1:{httpserver_port}")
+        self.send_to_httpserver.connect(f"{args.zmq_mode}127.0.0.1:{httpserver_port}")

         self.tokenizer = get_tokenizer(model_weightdir, tokenizor_mode, trust_remote_code=trust_remote_code)
         self.all_special_ids = set(self.tokenizer.all_special_ids)
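
A note on how the prefix composes (values are illustrative): the "127.0.0.1:{port}" suffix is kept in both modes, so under ipc it simply becomes part of the socket filename rather than a TCP address.

zmq_mode = "ipc:///tmp/_28765_"  # after the per-instance rewrite in api_start.py
detokenization_port = 8817
print(f"{zmq_mode}127.0.0.1:{detokenization_port}")
# -> ipc:///tmp/_28765_127.0.0.1:8817  (a unix socket file under /tmp)
# with --zmq_mode tcp:// the same f-string yields tcp://127.0.0.1:8817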

lightllm/server/httpserver/manager.py
Lines changed: 6 additions & 3 deletions

@@ -45,16 +45,16 @@ def __init__(
         self.args = args
         context = zmq.asyncio.Context(2)
         self.send_to_router = context.socket(zmq.PUSH)
-        self.send_to_router.connect(f"tcp://127.0.0.1:{router_port}")
+        self.send_to_router.connect(f"{args.zmq_mode}127.0.0.1:{router_port}")

         self.enable_multimodal = enable_multimodal
         if self.enable_multimodal:
             self.cache_client = rpyc.connect("localhost", cache_port)
             self.send_to_visual = context.socket(zmq.PUSH)
-            self.send_to_visual.connect(f"tcp://127.0.0.1:{visual_port}")
+            self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{visual_port}")

         self.recv_from_detokenization = context.socket(zmq.PULL)
-        self.recv_from_detokenization.bind(f"tcp://127.0.0.1:{httpserver_port}")
+        self.recv_from_detokenization.bind(f"{args.zmq_mode}127.0.0.1:{httpserver_port}")

         self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code)

@@ -67,6 +67,7 @@ def __init__(
         assert self.pd_mode in [NodeRole.P, NodeRole.D, NodeRole.NORMAL]
         self.id_gen = ReqIDGenerator()
         self.first_time_costs = MovingAverage()
+        self.per_token_costs = MovingAverage()
         # for some models the vocab size read from the tokenizer and from config.json differ
         self.vocab_size = max(get_vocab_size(args.model_dir), self.tokenizer.vocab_size)

@@ -340,6 +341,7 @@ async def _wait_to_token_package(
             pass
         total_cost_time_ms = (time.time() - start_time) * 1000
         mean_per_token_cost_time_ms = (total_cost_time_ms - first_token_cost_ms) / out_token_counter
+        self.per_token_costs.add(mean_per_token_cost_time_ms)
         x_request_id = request.headers.get("X-Request-Id", "")
         x_session_id = request.headers.get("X-Session-Id", "")
         prompt_cache_len = metadata.pop("prompt_cache_len", 0)
@@ -441,6 +443,7 @@ async def timer_to_pd_master(self):
                 await asyncio.sleep(3)
                 if log_count % 5 == 0:
                     logger.info(f"mean first cost: {self.first_time_costs.average()} ms")
+                    logger.info(f"mean per token cost: {self.per_token_costs.average()} ms")

             except Exception as e:
                 logger.error("connetion to pd_master has error")
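
MovingAverage itself is not shown in this diff; a minimal sketch consistent with how it is used here (add() a millisecond sample, average() the running mean) could be:

class MovingAverage:
    # Assumed behavior only; the real class lives elsewhere in lightllm.
    def __init__(self):
        self.total = 0.0
        self.count = 0

    def add(self, v):
        self.total += v
        self.count += 1

    def average(self):
        return self.total / self.count if self.count else 0.0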

lightllm/server/httpserver_for_pd_master/manager.py
Lines changed: 10 additions & 3 deletions

@@ -41,7 +41,9 @@ def __init__(
         self.id_to_event: Dict[int, asyncio.Event] = {}
         self.session = None
         self.first_time_costs = MovingAverage()
-        self.create_session_costs = MovingAverage()
+        self.prefill_create_session_costs = MovingAverage()
+        self.decode_create_session_costs = MovingAverage()
+        self.per_token_costs = MovingAverage()
         return

     async def register_pd(self, pd_info_json):
@@ -181,7 +183,7 @@ async def fetch_stream(
         req = await self._to_req_info(prompt, sampling_params, multimodal_params)
         create_start_time = time.time()
         async with self.session.post(p_node.to_llm_url(), json=req) as response:
-            self.create_session_costs.add((time.time() - create_start_time) * 1000)
+            self.prefill_create_session_costs.add((time.time() - create_start_time) * 1000)
             if response.status == 200:
                 async for line in response.content:
                     line = line.decode("utf-8").strip()
@@ -217,7 +219,9 @@ async def fetch_stream(
         sampling_params.suggested_dp_index = event.upkv_status.dp_index

         req = await self._to_req_info(prompt_ids, sampling_params, multimodal_params)
+        create_start_time = time.time()
         async with self.session.post(d_node.to_llm_url(), json=req) as response:
+            self.decode_create_session_costs.add((time.time() - create_start_time) * 1000)
             if response.status == 200:
                 async for line in response.content:
                     line = line.decode("utf-8").strip()
@@ -269,6 +273,7 @@ async def _wait_to_token_package(

         total_cost_time_ms = (time.time() - start_time) * 1000
         mean_per_token_cost_time_ms = (total_cost_time_ms - first_token_cost_ms) / out_token_counter
+        self.per_token_costs.add(mean_per_token_cost_time_ms)
         x_request_id = request.headers.get("X-Request-Id", "")
         x_session_id = request.headers.get("X-Session-Id", "")
         prompt_cache_len = metadata.pop("prompt_cache_len", 0)
@@ -312,5 +317,7 @@ async def handle_loop(self):
             # this could be made into a periodic task
             await asyncio.sleep(20)
             logger.info(f"mean first cost: {self.first_time_costs.average()} ms")
-            logger.info(f"create_session_costs: {self.create_session_costs.average()} ms")
+            logger.info(f"prefill mean create_session_costs: {self.prefill_create_session_costs.average()} ms")
+            logger.info(f"decode mean create_session_costs: {self.decode_create_session_costs.average()} ms")
+            logger.info(f"mean per token cost: {self.per_token_costs.average()} ms")
         return

lightllm/server/router/manager.py
Lines changed: 12 additions & 3 deletions

@@ -66,10 +66,10 @@ def __init__(self, args, router_port, detokenization_port, model_rpc_ports, metr

         context = zmq.asyncio.Context(2)
         self.recv_from_httpserver = context.socket(zmq.PULL)
-        self.recv_from_httpserver.bind(f"tcp://127.0.0.1:{router_port}")
+        self.recv_from_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{router_port}")

         self.send_to_detokenization = context.socket(zmq.PUSH)
-        self.send_to_detokenization.connect(f"tcp://127.0.0.1:{detokenization_port}")
+        self.send_to_detokenization.connect(f"{args.zmq_mode}127.0.0.1:{detokenization_port}")
         self.model_rpc_ports = model_rpc_ports

         self.is_splitfuse_mode = args.splitfuse_mode
@@ -283,14 +283,15 @@ async def _step(self):
             self.running_batch = new_batch
             await self._prefill_batch(self.running_batch)
             self._filter_runing_batch()
-            self.has_wait_tokens = 0
+            self.has_wait_tokens = self.max_wait_tokens
             return

         # there are running requests, but it is already time to schedule new requests and merge them in for inference
         if self.has_wait_tokens >= self.max_wait_tokens:
             new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
             self.has_wait_tokens = 0
             if new_mini_batch is not None:
+                self.has_wait_tokens = self.max_wait_tokens
                 self.stats_tool.count_prompt_tokens(new_mini_batch)
                 await self._prefill_batch(new_mini_batch)
                 if not new_mini_batch.is_clear():
@@ -426,6 +427,9 @@ def _update_init_status_to_batch(self, batch: Batch, req_to_req_status):

     def _update_out_status_to_batch(self, batch: Batch, req_to_out_status):
         new_batch_decode_need_tokens = [0 for _ in range(self.dp_size)]  # only meaningful in splitfuse mode
+
+        start_time = 0
+        # if the inference backend stamps a time into extral_info, use it to estimate serialization overhead; mainly for debugging
         for req_id, (
             req_status,
             cur_kv_len,
@@ -434,6 +438,8 @@ def _update_out_status_to_batch(self, batch: Batch, req_to_out_status):
             finish_status_value,
             extral_info,
         ) in req_to_out_status.items():
+            if extral_info is not None:
+                start_time = max(start_time, extral_info)
             req: Req = batch.id_to_reqs[req_id]
             req.req_status = req_status
             req.cur_kv_len = cur_kv_len
@@ -446,6 +452,9 @@ def _update_out_status_to_batch(self, batch: Batch, req_to_out_status):
             new_batch_decode_need_tokens[req_dp_index] += req.get_decode_need_tokens()

         batch.batch_decode_need_tokens = new_batch_decode_need_tokens
+        rpyc_cost_time = (time.time() - start_time) * 1000
+        if 8 <= rpyc_cost_time <= 1000:
+            logger.warning(f"rpyc use too much time {rpyc_cost_time} ms, batch_size {len(req_to_out_status)}")
         return

     def _can_decode(self, batch: Batch):
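
Two reading notes on the router changes, with a toy illustration below. First, after any successful prefill, has_wait_tokens is now reset to max_wait_tokens rather than 0, so the very next _step is already eligible to merge another waiting mini-batch instead of idling for max_wait_tokens decode steps first. Second, the rpyc timing warning only fires when some extral_info timestamp was actually provided: with start_time left at 0, rpyc_cost_time is the whole epoch in milliseconds and falls outside the 8 to 1000 ms window.

max_wait_tokens = 10

has_wait_tokens = 0                # old reset: must wait 10 decode steps to merge
assert not has_wait_tokens >= max_wait_tokens

has_wait_tokens = max_wait_tokens  # new reset: eligible to merge on the next step
assert has_wait_tokens >= max_wait_tokens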

lightllm/server/router/model_infer/mode_backend/continues_batch/decode_node_impl/decode_impl.py
Lines changed: 8 additions & 1 deletion

@@ -1,3 +1,4 @@
+import os
 import torch
 import torch.multiprocessing as mp
 import torch.distributed as dist
@@ -30,7 +31,13 @@ def init_custom(self):
         self.lock_nccl_group = dist.new_group(backend="gloo")
         from .decode_infer_rpyc import PDDecodeInferRpcServer

-        t = ThreadedServer(PDDecodeInferRpcServer(self), port=self.pd_rpyc_port, protocol_config={"allow_pickle": True})
+        socket_path = f"/tmp/decode_node_infer_rpyc_{self.pd_rpyc_port}"
+        if os.path.exists(socket_path):
+            os.remove(socket_path)
+
+        t = ThreadedServer(
+            PDDecodeInferRpcServer(self), socket_path=socket_path, protocol_config={"allow_pickle": True}
+        )
         threading.Thread(target=lambda: t.start(), daemon=True).start()
         return


lightllm/server/router/model_infer/mode_backend/continues_batch/decode_node_impl/decode_kv_move_manager.py
Lines changed: 4 additions & 1 deletion

@@ -87,7 +87,10 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]):
         self.infer_rpyc_objs: List[PDDecodeInferRpcServer] = []
         self.node_id_to_trans_obj: Dict[str, TransProcessObj] = {}
         for port in self.args.pd_tp_infer_rpyc_ports:
-            con = retry(max_attempts=20, wait_time=2)(rpyc.connect)("localhost", port, config={"allow_pickle": True})
+            socket_path = f"/tmp/decode_node_infer_rpyc_{port}"
+            from rpyc.utils.factory import unix_connect
+
+            con = retry(max_attempts=20, wait_time=2)(unix_connect)(socket_path, config={"allow_pickle": True})
             self.infer_rpyc_objs.append(con.root)
             logger.info(f"rpyc connect to port: {port} ok")
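
Both sides of this change replace TCP loopback rpyc with a unix domain socket: the decode worker binds a ThreadedServer to a socket file keyed by the rpyc port, and the KV-move manager dials it with rpyc's unix_connect. A minimal self-contained sketch of the same pattern (service name and path are illustrative):

import os
import threading
import time

import rpyc
from rpyc.utils.server import ThreadedServer
from rpyc.utils.factory import unix_connect

class EchoService(rpyc.Service):
    def exposed_echo(self, x):
        return x

socket_path = "/tmp/echo_rpyc_demo"
if os.path.exists(socket_path):
    os.remove(socket_path)  # a stale socket file from an earlier run breaks bind()

server = ThreadedServer(EchoService(), socket_path=socket_path,
                        protocol_config={"allow_pickle": True})
threading.Thread(target=server.start, daemon=True).start()
time.sleep(0.2)  # give the listener a moment before connecting

con = unix_connect(socket_path, config={"allow_pickle": True})
print(con.root.echo(42))  # 42

Unlike a TCP port, the socket file is namespaced by its path, which sidesteps port collisions between instances on the same machine, the same motivation as the ipc zmq_mode above.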
9396
