|
| 1 | +import time |
| 2 | +import asyncio |
| 3 | +import base64 |
| 4 | +import pickle |
| 5 | +import multiprocessing as mp |
1 | 6 | from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, Query |
2 | 7 | from threading import Lock |
3 | | -from typing import Dict |
| 8 | +from typing import Dict, List |
4 | 9 | from fastapi.responses import JSONResponse |
5 | 10 | from lightllm.utils.log_utils import init_logger |
6 | 11 | from ..pd_io_struct import PD_Master_Obj |
7 | | -import base64 |
8 | | -import pickle |
9 | | -import os |
10 | | -import requests |
| 12 | +from .nccl_tcp_store import start_tcp_store_server |
| 13 | +from lightllm.utils.envs_utils import get_env_start_args |
| 14 | +from lightllm.utils.process_check import start_parent_check_thread |
| 15 | + |
11 | 16 |
|
12 | 17 | logger = init_logger(__name__) |
13 | 18 | app = FastAPI() |
@@ -112,3 +117,90 @@ async def allocate_global_unique_multimodal_id_range(): |
112 | 117 | end_id = global_multimodal_embedding_id |
113 | 118 |
|
114 | 119 | return {"start_id": start_id, "end_id": end_id} |
| 120 | + |
| 121 | + |
| 122 | +global_store_port_to_process: Dict[int, mp.Process] = {} |
| 123 | +global_store_port_to_client_states: Dict[int, List[bool]] = {} |
| 124 | +global_store_port_lock = asyncio.Lock() |
| 125 | + |
| 126 | + |
| 127 | +@app.get("/start_tcp_store_server") |
| 128 | +async def http_start_tcp_store_server( |
| 129 | + tcp_store_port: int = Query(...), rank_id: int = Query(...), world_size: int = Query(...) |
| 130 | +): |
| 131 | + """ |
| 132 | + Start a TCP store server for NCCL communication. |
| 133 | +
|
| 134 | + Args: |
| 135 | + tcp_store_port (int): The port number for the TCP store server. |
| 136 | + rank_id (int): The rank ID of inference process. |
| 137 | + world_size (int): The world size of nccl group. |
| 138 | +
|
| 139 | + Returns: |
| 140 | + dict: A dictionary containing the status of the server. |
| 141 | + """ |
| 142 | + global global_store_port_to_process |
| 143 | + global global_store_port_to_client_states |
| 144 | + global global_store_port_lock |
| 145 | + |
| 146 | + args = get_env_start_args() |
| 147 | + |
| 148 | + if rank_id == 0: |
| 149 | + async with global_store_port_lock: |
| 150 | + if tcp_store_port in global_store_port_to_client_states: |
| 151 | + logger.error(f"tcp store server {tcp_store_port} already started, rank_id 0 find client state exists") |
| 152 | + assert False, f"tcp store server {tcp_store_port} already started, rank_id 0 find client state exists" |
| 153 | + |
| 154 | + if tcp_store_port in global_store_port_to_process: |
| 155 | + logger.warning(f"tcp store server {tcp_store_port} already started, kill and restart it") |
| 156 | + process = global_store_port_to_process[tcp_store_port] |
| 157 | + process.kill() |
| 158 | + process.join() |
| 159 | + |
| 160 | + global_store_port_to_process[tcp_store_port] = start_tcp_store_server( |
| 161 | + args.config_server_host, tcp_store_port |
| 162 | + ) |
| 163 | + |
| 164 | + world_size_state = [True for _ in range(world_size)] |
| 165 | + global_store_port_to_client_states[tcp_store_port] = world_size_state |
| 166 | + |
| 167 | + world_size_state[rank_id] = False |
| 168 | + |
| 169 | + start_time = time.time() |
| 170 | + while any(world_size_state): |
| 171 | + await asyncio.sleep(1) |
| 172 | + if time.time() - start_time > 60 * 3: |
| 173 | + logger.error( |
| 174 | + f"tcp store server {tcp_store_port} rank_id {rank_id} world_size {world_size} wait all quit timeout" |
| 175 | + ) |
| 176 | + async with global_store_port_lock: |
| 177 | + global_store_port_to_client_states.pop(tcp_store_port, None) |
| 178 | + raise Exception( |
| 179 | + f"tcp store server {tcp_store_port} rank_id {rank_id} world_size {world_size} wait timeout" |
| 180 | + ) |
| 181 | + |
| 182 | + async with global_store_port_lock: |
| 183 | + global_store_port_to_client_states.pop(tcp_store_port, None) |
| 184 | + |
| 185 | + return {"status": "ok"} |
| 186 | + else: |
| 187 | + start_time = time.time() |
| 188 | + while tcp_store_port not in global_store_port_to_client_states: |
| 189 | + await asyncio.sleep(1) |
| 190 | + if time.time() - start_time > 60 * 3: |
| 191 | + logger.error(f"tcp store port {tcp_store_port} rank_id {rank_id} world_size {world_size} state timeout") |
| 192 | + raise Exception( |
| 193 | + f"tcp store server {tcp_store_port} rank_id {rank_id} world_size {world_size} state timeout" |
| 194 | + ) |
| 195 | + |
| 196 | + world_size_state = global_store_port_to_client_states[tcp_store_port] |
| 197 | + |
| 198 | + assert ( |
| 199 | + world_size_state[rank_id] is True |
| 200 | + ), f"tcp store server {tcp_store_port} rank_id {rank_id} world_size {world_size} world_size_state error" |
| 201 | + world_size_state[rank_id] = False |
| 202 | + return {"status": "ok"} |
| 203 | + |
| 204 | + |
| 205 | +logger.info("config server start_parent_check_thread...") |
| 206 | +start_parent_check_thread() |
0 commit comments