11# Copyright (c) OpenMMLab. All rights reserved.
22import asyncio
3+ import atexit
34import pickle
45import signal
56from typing import TYPE_CHECKING
@@ -40,6 +41,7 @@ def __init__(self, model_path: str, tokenizer: object, engine_config: PytorchEng
4041 self .rpc_client = AsyncRPCClient (port = self .port )
4142
4243 super ().__init__ ()
44+ atexit .register (self .close )
4345
4446 def _start_mp_proc (self , model_path : str , tokenizer : object , engine_config : PytorchEngineConfig = None ):
4547 """Start mp proc."""
@@ -54,16 +56,17 @@ def _start_mp_proc(self, model_path: str, tokenizer: object, engine_config: Pyto
5456 condition = manager .Condition ()
5557 self .mp_ctx = mp .get_context ('spawn' )
5658 log_level = logger .level
57- self .proc = self .mp_ctx .Process (target = self ._mp_proc ,
58- args = (self .shared_dict , condition ),
59- kwargs = (dict (
60- model_path = model_path ,
61- tokenizer = tokenizer ,
62- engine_config = engine_config ,
63- log_level = log_level ,
64- )),
65- name = 'mp_engine_proc' ,
66- daemon = True )
59+ self .proc = self .mp_ctx .Process (
60+ target = self ._mp_proc ,
61+ args = (self .shared_dict , condition ),
62+ kwargs = (dict (
63+ model_path = model_path ,
64+ tokenizer = tokenizer ,
65+ engine_config = engine_config ,
66+ log_level = log_level ,
67+ )),
68+ name = 'mp_engine_proc' ,
69+ )
6770 self .proc .start ()
6871 logger .debug ('Receiving rpc server port from mp process.' )
6972 with condition :
@@ -156,10 +159,13 @@ async def _collective_rpc_streaming_async(self, func, *args, **kwargs):
156159
157160 def close (self ) -> None :
158161 """Close mp engine."""
162+ if self .proc is None :
163+ return
159164 logger .info ('Closing mp engine.' )
160165 self .rpc_client .stop ()
161166 self .proc .terminate ()
162167 self .proc .join (10 )
168+ self .proc = None
163169
164170 def start_loop (self ) -> None :
165171 """Start mp engine loop."""
0 commit comments