                                       get_current_placement_group,
                                       placement_group)
 
+from tensorrt_llm._ray_utils import unwrap_ray_errors
 from tensorrt_llm._utils import get_free_port
 from tensorrt_llm.logger import logger
 
@@ -57,48 +58,54 @@ def __init__(self,
5758 "runtime_env" : runtime_env
5859 }
5960
60- if os .environ .get ("TLLM_RAY_FORCE_LOCAL_CLUSTER" , "0" ) != "1" :
61- try :
62- ray .init (address = "auto" , ** ray_init_args )
63- logger .info (f"Attached to an existing Ray cluster." )
64- except ConnectionError :
65- logger .info (f"Ray cluster not found, starting a new one." )
66-
67- if not ray .is_initialized ():
68- ray .init (** ray_init_args )
61+ try :
62+ if os .environ .get ("TLLM_RAY_FORCE_LOCAL_CLUSTER" , "0" ) != "1" :
63+ try :
64+ ray .init (address = "auto" , ** ray_init_args )
65+ logger .info (f"Attached to an existing Ray cluster." )
66+ except ConnectionError :
67+ logger .info (f"Ray cluster not found, starting a new one." )
68+
69+ if not ray .is_initialized ():
70+ ray .init (** ray_init_args )
71+ self .has_start_local_cluser = True
72+ else :
73+ ray .init (address = "local" , ** ray_init_args )
6974 self .has_start_local_cluser = True
70- else :
71- ray .init (address = "local" , ** ray_init_args )
72- self .has_start_local_cluser = True
7375
74- self .world_size = model_world_size
75- self .tp_size = tp_size
76- self .master_address = ray .util .get_node_ip_address ()
77- self .master_port = get_free_port ()
78-
79- self .response_queue = RayAsyncQueue .options (runtime_env = {
80- "env_vars" : {
81- "TLLM_DISABLE_MPI" : "1"
82- }
83- }).remote ()
84- self .response_sync_queue = RaySyncQueue .options (runtime_env = {
85- "env_vars" : {
86- "TLLM_DISABLE_MPI" : "1"
87- }
88- }).remote ()
89- self .async_response_queue_weakref = self .create_actor_weak_ref (
90- self .response_queue )
91- self .sync_response_queue_weakref = self .create_actor_weak_ref (
92- self .response_sync_queue )
93- self .response_queue .warmup .remote ()
94- self .response_sync_queue .warmup .remote ()
95-
96- worker_kwargs = dict (** worker_kwargs ,
97- postproc_worker_config = postproc_worker_config ,
98- is_llm_executor = is_llm_executor ,
99- kv_connector_config = kv_connector_config )
100-
101- self .create_workers (RayGPUWorker , worker_kwargs )
76+ self .world_size = model_world_size
77+ self .tp_size = tp_size
78+ self .master_address = ray .util .get_node_ip_address ()
79+ self .master_port = get_free_port ()
80+
81+ self .response_queue = RayAsyncQueue .options (runtime_env = {
82+ "env_vars" : {
83+ "TLLM_DISABLE_MPI" : "1"
84+ }
85+ }).remote ()
86+ self .response_sync_queue = RaySyncQueue .options (runtime_env = {
87+ "env_vars" : {
88+ "TLLM_DISABLE_MPI" : "1"
89+ }
90+ }).remote ()
91+ self .async_response_queue_weakref = self .create_actor_weak_ref (
92+ self .response_queue )
93+ self .sync_response_queue_weakref = self .create_actor_weak_ref (
94+ self .response_sync_queue )
95+ self .response_queue .warmup .remote ()
96+ self .response_sync_queue .warmup .remote ()
97+
98+ worker_kwargs = dict (** worker_kwargs ,
99+ postproc_worker_config = postproc_worker_config ,
100+ is_llm_executor = is_llm_executor ,
101+ kv_connector_config = kv_connector_config )
102+
103+ self .create_workers (RayGPUWorker , worker_kwargs )
104+ except Exception as e :
105+ # Clean up the Ray resources early during exception
106+ self .shutdown ()
107+ logger .error (f"Failed to initialize RayExecutor: { e } " )
108+ raise e
102109
103110 @staticmethod
104111 def create_actor_weak_ref (actor_handle : ray .actor .ActorHandle ):
@@ -137,12 +144,19 @@ def create_workers(self, worker_cls, worker_kwargs):
             for rank in range(self.world_size)
         ]
 
-        ray.get([worker.__ray_ready__.remote() for worker in self.workers])
+        try:
+            ray.get([worker.__ray_ready__.remote() for worker in self.workers])
+        except ray.exceptions.ActorDiedError as e:
+            if "The actor died because of an error raised in its creation task" in str(
+                    e):
+                raise RuntimeError(
+                    "RayGPUWorker died during initialization") from e
+            raise
 
+    @unwrap_ray_errors()
     def call_all_ray_workers(self, func: str, leader_only: bool,
                              async_call: bool, *args, **kwargs):
         workers = (self.workers[0], ) if leader_only else self.workers
-
         if async_call:
             return [
                 getattr(worker, func).remote(*args, **kwargs)
@@ -154,6 +168,7 @@ def call_all_ray_workers(self, func: str, leader_only: bool,
             for worker in workers
         ])
 
+    @unwrap_ray_errors()
     def collective_rpc(self,
                        method: str,
                        args: tuple = (),
@@ -174,7 +189,6 @@ def collective_rpc(self,
                 # Ray actor doesn't work with __getattr__ delegation.
                 refs.append(w.call_worker_method.remote(method, *args,
                                                         **kwargs))
-
         return refs if non_block else ray.get(refs)
 
     def submit(self, request: GenerationRequest) -> GenerationResult:
@@ -224,11 +238,14 @@ def shutdown(self):
         self.workers = None
         if hasattr(self,
                    "placement_group") and self.placement_group is not None:
-            ray.util.remove_placement_group(self.placement_group)
+            # Only remove placement group if Ray is still initialized
+            # to avoid triggering auto_init_ray() during program exit
+            if ray.is_initialized():
+                ray.util.remove_placement_group(self.placement_group)
             self.placement_group = None
             self.bundle_indices = None
 
-        if self.has_start_local_cluser:
+        if self.has_start_local_cluser and ray.is_initialized():
             logger.debug("Shutting down Ray cluster")
             ray.shutdown()
 
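
Note on the new decorator: @unwrap_ray_errors() is imported from tensorrt_llm._ray_utils, whose implementation is not part of this diff. As a minimal, non-authoritative sketch, assuming the helper's purpose is to surface the user-level exception that Ray wraps inside ray.exceptions.RayTaskError, such a decorator factory could look roughly like this:

# Illustrative sketch only -- not the actual tensorrt_llm._ray_utils code.
# Assumption: the decorator re-raises the original exception that Ray wraps
# in RayTaskError, so callers see the underlying error type directly.
import functools

from ray.exceptions import RayTaskError


def unwrap_ray_errors():

    def decorator(func):

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except RayTaskError as e:
                # as_instanceof_cause() rebuilds the error as an instance of
                # the original exception class raised inside the actor/task.
                raise e.as_instanceof_cause() from e

        return wrapper

    return decorator

The real helper may differ (for example, it might also handle ray.exceptions.ActorDiedError or preserve tracebacks differently); the sketch only conveys the unwrapping pattern applied here to call_all_ray_workers and collective_rpc.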