Skip to content

Commit f710f51

Browse files
authored
check_env in multiprocess (#3879)
1 parent 01b52f9 commit f710f51

File tree

3 files changed

+43
-10
lines changed

3 files changed

+43
-10
lines changed

lmdeploy/pytorch/engine/engine_checker.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,27 @@ def check(self):
9696
message='num_gpu_blocks should be greater than 16, '
9797
f'but got {num_gpu_blocks}. Set num_gpu_blocks to 0 to automatically '
9898
'determine the number of GPU blocks based on the model size and device memory.')
99+
100+
def _handle_impl(self):
101+
return super().handle()
102+
103+
def handle(self):
104+
import multiprocessing as mp
105+
from concurrent.futures import ProcessPoolExecutor
106+
107+
from lmdeploy.pytorch import envs
108+
if not envs.enable_check_env:
109+
return
110+
111+
current_proc = mp.current_process()
112+
if not current_proc.daemon:
113+
mp_ctx = mp.get_context('spawn')
114+
with ProcessPoolExecutor(mp_context=mp_ctx) as executor:
115+
try:
116+
executor.submit(self._handle_impl).result()
117+
except SystemExit:
118+
exit(1)
119+
except BaseException as e:
120+
self.log_and_exit(e, mod_name='Engine')
121+
else:
122+
return self._handle_impl()

lmdeploy/pytorch/engine/mp_engine/zmq_engine.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
import asyncio
3+
import atexit
34
import pickle
45
import signal
56
from typing import TYPE_CHECKING
@@ -40,6 +41,7 @@ def __init__(self, model_path: str, tokenizer: object, engine_config: PytorchEng
4041
self.rpc_client = AsyncRPCClient(port=self.port)
4142

4243
super().__init__()
44+
atexit.register(self.close)
4345

4446
def _start_mp_proc(self, model_path: str, tokenizer: object, engine_config: PytorchEngineConfig = None):
4547
"""Start mp proc."""
@@ -54,16 +56,17 @@ def _start_mp_proc(self, model_path: str, tokenizer: object, engine_config: Pyto
5456
condition = manager.Condition()
5557
self.mp_ctx = mp.get_context('spawn')
5658
log_level = logger.level
57-
self.proc = self.mp_ctx.Process(target=self._mp_proc,
58-
args=(self.shared_dict, condition),
59-
kwargs=(dict(
60-
model_path=model_path,
61-
tokenizer=tokenizer,
62-
engine_config=engine_config,
63-
log_level=log_level,
64-
)),
65-
name='mp_engine_proc',
66-
daemon=True)
59+
self.proc = self.mp_ctx.Process(
60+
target=self._mp_proc,
61+
args=(self.shared_dict, condition),
62+
kwargs=(dict(
63+
model_path=model_path,
64+
tokenizer=tokenizer,
65+
engine_config=engine_config,
66+
log_level=log_level,
67+
)),
68+
name='mp_engine_proc',
69+
)
6770
self.proc.start()
6871
logger.debug('Receiving rpc server port from mp process.')
6972
with condition:
@@ -156,10 +159,13 @@ async def _collective_rpc_streaming_async(self, func, *args, **kwargs):
156159

157160
def close(self) -> None:
158161
"""Close mp engine."""
162+
if self.proc is None:
163+
return
159164
logger.info('Closing mp engine.')
160165
self.rpc_client.stop()
161166
self.proc.terminate()
162167
self.proc.join(10)
168+
self.proc = None
163169

164170
def start_loop(self) -> None:
165171
"""Start mp engine loop."""

lmdeploy/pytorch/envs.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,9 @@ def _patched_get_env(
118118
# logging
119119
log_file = os.getenv('LMDEPLOY_LOG_FILE', None)
120120

121+
# check env
122+
enable_check_env = env_to_bool('LMDEPLOY_ENABLE_CHECK_ENV', True)
123+
121124
# dlblas
122125
# we don't need to read this, it would be passed to ray workers
123126
# If Ray is launched from outside, it may fail to access the environment variables.

0 commit comments

Comments
 (0)