Skip to content

Commit d478151

Browse files
committed
add mps for multimodal
1 parent 6c46415 commit d478151

File tree

6 files changed

+150
-6
lines changed

6 files changed

+150
-6
lines changed

lightllm/server/api_cli.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,9 @@ def make_argument_parser() -> argparse.ArgumentParser:
193193
parser.add_argument(
194194
"--enable_multimodal", action="store_true", help="Whether or not to allow to load additional multimodal models."
195195
)
196+
parser.add_argument(
197+
"--enable_mps", action="store_true", help="Whether to enable nvidia mps for multimodal service."
198+
)
196199
parser.add_argument("--enable_custom_allreduce", action="store_true", help="Whether to disable cutom allreduce.")
197200
parser.add_argument("--enable_custom_allgather", action="store_true", help="Whether to enable cutom allgather.")
198201
parser.add_argument(

lightllm/server/api_start.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,13 @@ def signal_handler(sig, frame):
6363
def normal_or_p_d_start(args):
6464
set_unique_server_name(args)
6565

66+
if args.enable_mps:
67+
from lightllm.utils.device_utils import enable_mps, set_gpu_exclusive_mode
68+
69+
enable_mps()
70+
for i in range(args.tp):
71+
set_gpu_exclusive_mode(gpu_index=i)
72+
6673
if args.run_mode not in ["normal", "prefill", "decode"]:
6774
return
6875

lightllm/server/visualserver/manager.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,10 @@ async def wait_to_model_ready(self):
5555
for dp_rank_id in range(self.vit_dp):
5656
tp_ports_each_dp = self.visual_model_rpc_ports[dp_rank_id]
5757
for tp_rank_id in range(self.vit_tp):
58-
rpc_model = await start_model_process(port=tp_ports_each_dp[tp_rank_id], vit_tp=self.vit_tp)
58+
device_id = dp_rank_id * self.vit_tp + tp_rank_id
59+
rpc_model = await start_model_process(
60+
port=tp_ports_each_dp[tp_rank_id], vit_tp=self.vit_tp, device_id=device_id
61+
)
5962
self.model_rpcs[dp_rank_id].append(rpc_model)
6063

6164
init_model_ret = []
@@ -159,7 +162,6 @@ def start_visual_process(args, router_port, visual_port, cache_port, model_rpc_p
159162
# 注册graceful 退出的处理
160163
graceful_registry(inspect.currentframe().f_code.co_name)
161164
start_parent_check_thread()
162-
163165
try:
164166
visualserver = VisualManager(args, router_port, visual_port, cache_port, model_rpc_ports)
165167
asyncio.run(visualserver.wait_to_model_ready())

lightllm/server/visualserver/model_infer/model_rpc.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,19 +139,28 @@ async def encode(self, images: List[ImageItem]):
139139
return ans
140140

141141

142-
def _init_env(port, device_id):
    # Register graceful-shutdown handling for this worker process.
    graceful_registry(inspect.currentframe().f_code.co_name)
    from lightllm.utils.device_utils import set_sm_limit

    # Cap this visual-server process at 60% of the GPU's SMs via MPS
    # (CUDA_MPS_ACTIVE_THREAD_PERCENTAGE) — this limits compute share,
    # not GPU memory.
    set_sm_limit(60, device_id)

    # Blocking RPC server loop; pickle is allowed because callers are
    # trusted local processes of the same service.
    t = ThreadedServer(VisualModelRpcServer(), port=port, protocol_config={"allow_pickle": True})
    t.start()
    return
149152

150153

151-
async def start_model_process(port, vit_tp):
154+
async def start_model_process(port, vit_tp, device_id):
152155
import multiprocessing
153156

154-
proc = multiprocessing.Process(target=_init_env, args=(port,))
157+
proc = multiprocessing.Process(
158+
target=_init_env,
159+
args=(
160+
port,
161+
device_id,
162+
),
163+
)
155164
proc.start()
156165
await asyncio.sleep(2)
157166
repeat_count = 0

lightllm/utils/device_utils.py

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
import os
2-
from functools import lru_cache
2+
import time
3+
import shutil
34
import subprocess
5+
from functools import lru_cache
6+
from lightllm.utils.log_utils import init_logger
7+
8+
logger = init_logger(__name__)
49

510

611
@lru_cache(maxsize=None)
@@ -99,3 +104,110 @@ def has_nvlink():
99104
except subprocess.CalledProcessError:
100105
# If there's an error (e.g., nvidia-smi is not installed or another issue), assume no NVLink
101106
return False
107+
108+
109+
def is_mps_running(verbose=False):
    """Return True if the NVIDIA MPS control daemon is running.

    Detection greps the process table for ``nvidia-cuda-mps-control``;
    the ``[n]`` bracket trick prevents the grep process itself from
    matching its own command line.

    Args:
        verbose: When True, log the matching ``ps`` output (fix: this
            parameter was previously accepted but ignored).

    Returns:
        bool: True when the daemon process is present, False otherwise.
    """
    result = subprocess.run(
        "ps -ef | grep '[n]vidia-cuda-mps-control'",
        shell=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )
    # The pipeline's exit status is grep's: 0 iff at least one line matched.
    running = result.returncode == 0
    if verbose:
        logger.info("MPS daemon %s. ps output: %s", "found" if running else "not found", result.stdout.strip())
    return running
118+
119+
120+
def stop_mps():
    """Shut down the NVIDIA MPS control daemon if it is currently running."""
    if not is_mps_running():
        logger.info("MPS is not running, no need to stop.")
        return
    # Piping "quit" to the control binary asks the running daemon to exit.
    result = subprocess.run("echo quit | nvidia-cuda-mps-control", shell=True)
    logger.info("Stopping MPS...")
    if result.returncode == 0:
        logger.info("MPS stopped successfully.")
    else:
        logger.warning("Failed to stop MPS.")
130+
131+
132+
def enable_mps():
    """Start the NVIDIA MPS control daemon if it is not already running.

    Launches ``nvidia-cuda-mps-control -d`` (daemon mode), then waits
    briefly and verifies the daemon actually came up.

    Fix: the launch command's return code is now checked *before* the
    10-second settle delay, so a failed launch no longer stalls startup.
    """
    if is_mps_running():
        logger.info("MPS is already running, no need to start.")
        return

    ret = os.system("nvidia-cuda-mps-control -d")
    if ret != 0:
        # Launch failed outright; no point waiting for a daemon that never started.
        logger.warning("Failed to start MPS.")
        return

    # Give the daemon time to come up before verifying it is alive.
    time.sleep(10)
    if is_mps_running():
        logger.info("MPS started successfully.")
    return
146+
147+
148+
def get_gpu_compute_mode(gpu_index=0):
    """Query one GPU's compute mode via ``nvidia-smi``.

    Args:
        gpu_index: Index of the GPU to query.

    Returns:
        The mode string reported by nvidia-smi (e.g. "Default",
        "Exclusive_Process"), or None when nvidia-smi is missing or the
        query fails.
    """
    try:
        if shutil.which("nvidia-smi") is None:
            logger.warning("nvidia-smi not found in PATH.")
            return None

        query = ["nvidia-smi", "-i", str(gpu_index), "--query-gpu=compute_mode", "--format=csv,noheader"]
        proc = subprocess.run(query, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
        if proc.returncode != 0:
            logger.warning(f"Failed to query compute mode: {proc.stderr.strip()}")
            return None

        return proc.stdout.strip()
    except Exception as e:
        logger.warning(f"Exception occurred while checking GPU compute mode: {e}")
        return None
167+
168+
169+
def set_gpu_exclusive_mode(gpu_index=0):
    """Switch one GPU into EXCLUSIVE_PROCESS compute mode.

    NOTE(review): changing compute mode typically requires elevated
    privileges — confirm for the target deployment.

    Returns:
        True when nvidia-smi reports success, False otherwise.
    """
    logger.info(f"Setting GPU {gpu_index} to EXCLUSIVE_PROCESS mode...")
    proc = subprocess.run(
        ["nvidia-smi", "-i", str(gpu_index), "-c", "EXCLUSIVE_PROCESS"],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )
    if proc.returncode != 0:
        logger.warning(f"Failed to set EXCLUSIVE_PROCESS mode: {proc.stderr.strip()}")
        return False
    logger.info(f"GPU {gpu_index} set to EXCLUSIVE_PROCESS mode.")
    return True
183+
184+
185+
def set_gpu_default_mode(gpu_index=0):
    """Restore one GPU to DEFAULT compute mode (undo of exclusive mode).

    Returns:
        True when nvidia-smi reports success, False otherwise.
    """
    logger.info(f"Setting GPU {gpu_index} to DEFAULT mode...")
    proc = subprocess.run(
        ["nvidia-smi", "-i", str(gpu_index), "-c", "DEFAULT"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
    )
    if proc.returncode != 0:
        logger.warning(f"Failed to set DEFAULT mode: {proc.stderr.strip()}")
        return False
    logger.info(f"GPU {gpu_index} set to DEFAULT mode.")
    return True
196+
197+
198+
def set_sm_limit(percent: int, gpu_index=0):
    """Cap the SM share an MPS client process may use on one GPU.

    Exports ``CUDA_MPS_ACTIVE_THREAD_PERCENTAGE`` for this process (and
    any children it spawns), but only when the GPU is already in
    EXCLUSIVE_PROCESS compute mode.

    Args:
        percent: SM usage cap, 1-100 inclusive.
        gpu_index: GPU whose compute mode is verified first.

    Returns:
        True when the environment variable was set, False otherwise.
    """
    if percent < 1 or percent > 100:
        logger.error("SM usage percentage must be between 1 and 100.")
        return False

    mode = get_gpu_compute_mode(gpu_index)
    if mode != "Exclusive_Process":
        logger.warning(f"Cannot set SM limit. GPU {gpu_index} is in '{mode}' mode, not 'Exclusive_Process'.")
        return False

    os.environ["CUDA_MPS_ACTIVE_THREAD_PERCENTAGE"] = str(percent)
    logger.info(f"Set CUDA_MPS_ACTIVE_THREAD_PERCENTAGE to {percent}% for GPU {gpu_index}.")
    return True

lightllm/utils/start_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ def start_submodule_processes(self, start_funcs=[], start_args=[]):
4141
return
4242

4343
def terminate_all_processes(self):
44+
from lightllm.utils.envs_utils import get_env_start_args
45+
46+
is_enable_mps = get_env_start_args().enable_mps
47+
world_size = get_env_start_args().tp
48+
4449
def kill_recursive(proc):
4550
try:
4651
parent = psutil.Process(proc.pid)
@@ -57,6 +62,12 @@ def kill_recursive(proc):
5762
if proc.is_alive():
5863
kill_recursive(proc)
5964
proc.join()
65+
if is_enable_mps:
66+
from lightllm.utils.device_utils import stop_mps, set_gpu_default_mode
67+
68+
stop_mps()
69+
for i in range(world_size):
70+
set_gpu_default_mode(gpu_index=i)
6071
logger.info("All processes terminated gracefully.")
6172

6273

0 commit comments

Comments
 (0)