Skip to content

Commit db1f8c1

Browse files
committed
reformat.
1 parent 9560b20 commit db1f8c1

File tree

10 files changed

+76
-61
lines changed

10 files changed

+76
-61
lines changed

lightllm/server/api_start.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -283,10 +283,10 @@ def pd_master_start(args):
283283
set_unique_server_name(args)
284284
if args.run_mode != "pd_master":
285285
return
286-
286+
287287
# when use config_server to support multi pd_master node, we
288288
# need generate unique node id for each pd_master node.
289-
# otherwise, we use the 0 for single pd_master node.
289+
# otherwise, we use the 0 for single pd_master node.
290290
if args.config_server_host and args.config_server_port:
291291
args.pd_node_id = uuid.uuid4().int
292292
else:
@@ -344,7 +344,7 @@ def config_server_start(args):
344344
set_unique_server_name(args)
345345
if args.run_mode != "config_server":
346346
return
347-
347+
348348
logger.info(f"all start args:{args}")
349349

350350
set_env_start_args(args)

lightllm/server/config_server/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@
99
- Global configuration query: Provides mechanisms for querying shared
1010
configuration data across the system.
1111
- Designed for distributed systems operating in PD separation mode.
12-
"""
12+
"""

lightllm/server/config_server/api_http.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
logger = init_logger(__name__)
1313
app = FastAPI()
1414

15-
registered_pd_master_objs:Dict[str, PD_Master_Obj] = {}
15+
registered_pd_master_objs: Dict[str, PD_Master_Obj] = {}
1616
registered_pd_master_obj_lock = Lock()
1717

1818
global_req_id = 0
@@ -61,13 +61,15 @@ async def websocket_endpoint(websocket: WebSocket):
6161
registered_pd_master_objs.pop(registered_pd_master_obj.node_id, None)
6262
return
6363

64+
6465
@app.get("/registered_objects")
6566
async def get_registered_objects():
6667
with registered_pd_master_obj_lock:
6768
serialized_data = pickle.dumps(registered_pd_master_objs)
68-
base64_encoded = base64.b64encode(serialized_data).decode('utf-8')
69+
base64_encoded = base64.b64encode(serialized_data).decode("utf-8")
6970
return {"data": base64_encoded}
7071

72+
7173
@app.get("/allocate_global_unique_id_range")
7274
async def allocate_global_id_range():
7375
"""
@@ -85,8 +87,8 @@ async def allocate_global_id_range():
8587
global global_req_id
8688
range_size = 800000
8789
with global_req_id_lock:
88-
if global_req_id + range_size > 2**63 - 1:
89-
global_req_id = 0
90+
if global_req_id + range_size > 2 ** 63 - 1:
91+
global_req_id = 0
9092
start_id = global_req_id
9193
global_req_id += range_size
9294
end_id = global_req_id

lightllm/server/httpserver/pd_loop.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
logger = init_logger(__name__)
1717

18+
1819
async def timer_log(manager: HttpServerManager):
1920
while True:
2021
await asyncio.sleep(30)
@@ -32,13 +33,13 @@ async def pd_handle_loop(manager: HttpServerManager):
3233

3334
asyncio.create_task(timer_log(manager))
3435

35-
id_to_handle_task:Dict[int, asyncio.Task] = {}
36+
id_to_handle_task: Dict[int, asyncio.Task] = {}
3637

3738
while True:
3839
try:
3940
id_to_pd_master_obj = await _get_pd_master_objs(manager.args)
4041
logger.info(f"get pd_master_objs {id_to_pd_master_obj}")
41-
42+
4243
if id_to_pd_master_obj is not None:
4344
for node_id, pd_master_obj in id_to_handle_task.items():
4445
if node_id not in id_to_pd_master_obj:
@@ -51,7 +52,7 @@ async def pd_handle_loop(manager: HttpServerManager):
5152
id_to_handle_task[node_id] = asyncio.create_task(_pd_handle_task(manager, pd_master_obj))
5253

5354
await asyncio.sleep(30)
54-
55+
5556
except Exception as e:
5657
logger.exception(str(e))
5758
await asyncio.sleep(10)
@@ -70,7 +71,7 @@ async def _pd_handle_task(manager: HttpServerManager, pd_master_obj: PD_Master_O
7071
try:
7172
uri = f"ws://{pd_master_obj.host_ip_port}/pd_register"
7273
async with websockets.connect(uri, max_queue=(2048 * 1024, 2048 * 1023)) as websocket:
73-
74+
7475
sock = websocket.transport.get_extra_info("socket")
7576
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
7677

@@ -83,35 +84,35 @@ async def _pd_handle_task(manager: HttpServerManager, pd_master_obj: PD_Master_O
8384
"mode": manager.pd_mode.value,
8485
"start_args": args_dict,
8586
}
86-
87+
8788
await websocket.send(json.dumps(regist_json))
8889
logger.info(f"Sent registration JSON: {regist_json}")
89-
90+
9091
# 转发任务
91-
forwarding_tokens_task = asyncio.create_task(
92-
_up_tokens_to_pd_master(forwarding_queue, websocket)
93-
)
94-
92+
forwarding_tokens_task = asyncio.create_task(_up_tokens_to_pd_master(forwarding_queue, websocket))
93+
9594
# 接收 pd master 发来的请求,并推理后,将生成的token转发回pd master。
9695
while True:
9796
recv_bytes = await websocket.recv()
9897
obj = pickle.loads(recv_bytes)
9998
if obj[0] == ObjType.REQ:
10099
prompt, sampling_params, multimodal_params = obj[1]
101-
asyncio.create_task(_pd_process_generate(manager, prompt, sampling_params, multimodal_params, forwarding_queue))
100+
asyncio.create_task(
101+
_pd_process_generate(manager, prompt, sampling_params, multimodal_params, forwarding_queue)
102+
)
102103
elif obj[0] == ObjType.ABORT:
103104
group_req_id = obj[1]
104105
await manager.abort(group_req_id)
105106
else:
106107
logger.error(f"recevie error obj {str(obj)}")
107-
108+
108109
except asyncio.CancelledError:
109110
# 如果任务被取消,则退出循环
110111
logger.warning(f"forwarding_tokens_task {pd_master_obj} cancelled")
111112
if forwarding_tokens_task is not None:
112113
forwarding_tokens_task.cancel()
113114
return
114-
115+
115116
except Exception as e:
116117
logger.error("connetion to pd_master has error")
117118
logger.exception(str(e))
@@ -122,7 +123,7 @@ async def _pd_handle_task(manager: HttpServerManager, pd_master_obj: PD_Master_O
122123
logger.info("reconnection to pd_master")
123124

124125

125-
async def _get_pd_master_objs(args)->Optional[Dict[int, PD_Master_Obj]]:
126+
async def _get_pd_master_objs(args) -> Optional[Dict[int, PD_Master_Obj]]:
126127
"""
127128
get_pd_master_objs 主要负责从 pd master 获取所有的pd master对象。
128129
"""
@@ -135,15 +136,15 @@ async def _get_pd_master_objs(args)->Optional[Dict[int, PD_Master_Obj]]:
135136
ans = dict()
136137
ans[0] = PD_Master_Obj(node_id=0, host_ip_port=f"{args.pd_master_ip}:{args.pd_master_port}")
137138
return ans
138-
139+
139140
# 使用 config_server 服务来发现所有的 pd_master 节点。
140141
uri = f"ws://{args.config_server_host}:{args.config_server_port}/registered_objects"
141142

142143
try:
143144
async with httpx.AsyncClient() as client:
144145
response = await client.get(uri)
145146
if response.status_code == 200:
146-
base64data = response.json()["data"]
147+
base64data = response.json()["data"]
147148
id_to_pd_master_obj = pickle.loads(base64.b64decode(base64data))
148149
return id_to_pd_master_obj
149150
else:
@@ -154,8 +155,11 @@ async def _get_pd_master_objs(args)->Optional[Dict[int, PD_Master_Obj]]:
154155
await asyncio.sleep(10)
155156
return None
156157

158+
157159
# 触发推理的task
158-
async def _pd_process_generate(manager: HttpServerManager, prompt, sampling_params, multimodal_params, forwarding_queue:AsyncQueue):
160+
async def _pd_process_generate(
161+
manager: HttpServerManager, prompt, sampling_params, multimodal_params, forwarding_queue: AsyncQueue
162+
):
159163
try:
160164
async for sub_req_id, request_output, metadata, finish_status in manager.generate(
161165
prompt, sampling_params, multimodal_params, None
@@ -175,4 +179,3 @@ async def _up_tokens_to_pd_master(forwarding_queue: AsyncQueue, websocket):
175179
handle_list = await forwarding_queue.wait_to_get_all_data()
176180
if handle_list:
177181
await websocket.send(pickle.dumps((ObjType.TOKEN_PACKS, handle_list)))
178-

lightllm/server/httpserver_for_pd_master/register_loop.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
logger = init_logger(__name__)
1212

13+
1314
async def register_loop(manager: HttpServerManagerForPDMaster):
1415
assert manager.args.host not in ["127.0.0.1", "localhost"], "pd mode must specify host ip"
1516

@@ -23,15 +24,17 @@ async def register_loop(manager: HttpServerManagerForPDMaster):
2324
try:
2425
uri = f"ws://{manager.args.config_server_host}:{manager.args.config_server_port}/pd_master_register"
2526
async with websockets.connect(uri, max_queue=(2048 * 1024, 2048 * 1023)) as websocket:
26-
27+
2728
sock = websocket.transport.get_extra_info("socket")
2829
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
2930

30-
pd_master_obj = PD_Master_Obj(node_id=manager.args.pd_node_id, host_ip_port=f"{manager.host_ip}:{manager.args.port}")
31+
pd_master_obj = PD_Master_Obj(
32+
node_id=manager.args.pd_node_id, host_ip_port=f"{manager.host_ip}:{manager.args.port}"
33+
)
3134

3235
await websocket.send(pickle.dumps(pd_master_obj))
3336
logger.info(f"Sent registration pd_master obj: {pd_master_obj}")
34-
37+
3538
while True:
3639
await websocket.send("heartbeat")
3740
await asyncio.sleep(60)
@@ -41,4 +44,3 @@ async def register_loop(manager: HttpServerManagerForPDMaster):
4144
logger.exception(str(e))
4245
await asyncio.sleep(10)
4346
logger.info("reconnection to config_server")
44-

lightllm/server/pd_io_struct.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,17 @@ def __post_init__(self):
4646

4747
def to_llm_url(self):
4848
return f"http://{self.client_ip_port}/pd_generate_stream"
49-
49+
50+
5051
@dataclass
5152
class PD_Master_Obj:
5253
node_id: int
5354
host_ip_port: str
54-
55+
5556
def to_log_str(self):
5657
return f"PD_MASTER host_ip_port: {self.host_ip_port} node_id: {self.node_id}"
5758

59+
5860
@dataclass
5961
class UpKVStatus:
6062
type: str = "kv_move_status"
@@ -73,7 +75,7 @@ def __post_init__(self):
7375
error_info = "group_request_id only can be int"
7476
logger.error(error_info)
7577
raise ValueError(error_info)
76-
78+
7779
if not isinstance(self.pd_master_node_id, int):
7880
error_info = "pd_master_node_id only can be int"
7981
logger.error(error_info)
@@ -163,4 +165,4 @@ def get_cost_time(self):
163165
@dataclass
164166
class KVMoveTaskGroup:
165167
tasks: List[KVMoveTask]
166-
connect_id: str
168+
connect_id: str

lightllm/server/req_id_generator.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@ def __init__(self):
1717
from lightllm.server.core.objs.atomic_lock import AtomicShmLock
1818
from lightllm.server.core.objs.shm_array import ShmArray
1919
from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args
20-
20+
2121
self.args = get_env_start_args()
22-
self.use_config_server = self.args.config_server_host and self.args.config_server_port and self.args.run_mode == "pd_master"
22+
self.use_config_server = (
23+
self.args.config_server_host and self.args.config_server_port and self.args.run_mode == "pd_master"
24+
)
2325
self.current_id = ShmArray(f"{get_unique_server_name()}_req_id_gen", (2,), dtype=np.int64)
2426
self.current_id.create_shm()
2527
self.current_id.arr[0] = 0
@@ -43,7 +45,9 @@ def _check_and_set_new_id_range(self):
4345
# 保证id满足倍乘关系
4446
self.current_id.arr[0] = (id_range["start_id"] // MAX_BEST_OF + 1) * MAX_BEST_OF
4547
self.current_id.arr[1] = id_range["end_id"]
46-
assert self.current_id.arr[0] + MAX_BEST_OF < self.current_id.arr[1], f"get id range error {self.current_id.arr[0]} {self.current_id.arr[1]}"
48+
assert (
49+
self.current_id.arr[0] + MAX_BEST_OF < self.current_id.arr[1]
50+
), f"get id range error {self.current_id.arr[0]} {self.current_id.arr[1]}"
4751
return
4852
else:
4953
raise RuntimeError(f"Failed to fetch ID range from config server: {response.status_code}")

lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,11 @@ def put_to_radix_loop(self):
160160
f"{func_name} put kv to radix cache ok, req_id: {task.id()} cost_time {task.get_cost_time()} s"
161161
)
162162
self.manager.up_status_in_queue.put(
163-
UpKVStatus(group_request_id=task.group_request_id,
164-
dp_index=task.decode_dp_index,
165-
pd_master_node_id=task.decode_node.pd_master_node_id)
163+
UpKVStatus(
164+
group_request_id=task.group_request_id,
165+
dp_index=task.decode_dp_index,
166+
pd_master_node_id=task.decode_node.pd_master_node_id,
167+
)
166168
)
167169
logger.info(f"{func_name} up kv status req_id: {task.id()} finished")
168170
move_tasks.clear()

lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/up_status.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,21 +26,21 @@ def __init__(self, args, task_in_queue: mp.Queue, task_out_queue: mp.Queue):
2626

2727
def thread_loop(self):
2828
asyncio.run(self.task_loop())
29-
29+
3030
async def task_loop(self):
3131

32-
self.id_to_handle_task:Dict[int, asyncio.Task] = {}
33-
self.id_to_handle_queue:Dict[int, asyncio.Queue] = {}
32+
self.id_to_handle_task: Dict[int, asyncio.Task] = {}
33+
self.id_to_handle_queue: Dict[int, asyncio.Queue] = {}
3434

3535
asyncio.create_task(self.dispatch_task_loop())
3636

3737
while True:
3838
try:
3939
from lightllm.server.httpserver.pd_loop import _get_pd_master_objs
40-
40+
4141
id_to_pd_master_obj = await _get_pd_master_objs(self.args)
4242
logger.info(f"get pd_master_objs {id_to_pd_master_obj}")
43-
43+
4444
if id_to_pd_master_obj is not None:
4545
for node_id, pd_master_obj in self.id_to_handle_task.items():
4646
if node_id not in id_to_pd_master_obj:
@@ -55,23 +55,23 @@ async def task_loop(self):
5555
self.id_to_handle_task[node_id] = asyncio.create_task(self.up_kv_status_task(pd_master_obj))
5656

5757
await asyncio.sleep(30)
58-
58+
5959
except Exception as e:
6060
logger.exception(str(e))
6161
await asyncio.sleep(10)
6262

6363
async def dispatch_task_loop(self):
64-
while True:
65-
try:
66-
loop = asyncio.get_event_loop()
67-
upkv_status: UpKVStatus = await loop.run_in_executor(None, self.task_queue.get)
68-
if upkv_status.pd_master_node_id in self.id_to_handle_queue:
69-
await self.id_to_handle_queue[upkv_status.pd_master_node_id].put(upkv_status)
70-
else:
71-
logger.warning(f"upstatus {upkv_status} no connection to pd_master, drop it")
72-
except BaseException as e:
73-
logger.exception(str(e))
74-
await asyncio.sleep(10)
64+
while True:
65+
try:
66+
loop = asyncio.get_event_loop()
67+
upkv_status: UpKVStatus = await loop.run_in_executor(None, self.task_queue.get)
68+
if upkv_status.pd_master_node_id in self.id_to_handle_queue:
69+
await self.id_to_handle_queue[upkv_status.pd_master_node_id].put(upkv_status)
70+
else:
71+
logger.warning(f"upstatus {upkv_status} no connection to pd_master, drop it")
72+
except BaseException as e:
73+
logger.exception(str(e))
74+
await asyncio.sleep(10)
7575

7676
async def up_kv_status_task(self, pd_master_obj: PD_Master_Obj):
7777
while True:
@@ -98,15 +98,14 @@ async def up_kv_status_task(self, pd_master_obj: PD_Master_Obj):
9898
except asyncio.CancelledError:
9999
logger.info(f"up_kv_status_task {pd_master_obj} cancelled")
100100
return
101-
101+
102102
except Exception as e:
103103
logger.error(f"connetion to pd_master {pd_master_obj} has error: {str(e)}")
104104
logger.exception(str(e))
105105
await asyncio.sleep(10)
106106
logger.info("reconnection to pd_master")
107107

108108

109-
110109
def _init_env(args, task_in_queue: mp.Queue, task_out_queue: mp.Queue):
111110
graceful_registry(inspect.currentframe().f_code.co_name)
112111
up_kv_manager = UpStatusManager(args, task_in_queue, task_out_queue)

0 commit comments

Comments
 (0)