Skip to content

Commit 3205aab

Browse files
committed
cleanup
1 parent d32df78 commit 3205aab

File tree

1 file changed: +17 additions, -22 deletions

tensorrt_llm/executor/ray_executor.py

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import os
21
import asyncio
2+
import os
33
from typing import Any, Dict, List, Optional, Tuple
44

55
try:
@@ -80,16 +80,16 @@ def __init__(self,
8080
self.master_port = get_free_port()
8181
self.use_rpc = ray_use_rpc()
8282

83-
self.worker_kwargs = dict(**worker_kwargs,
84-
postproc_worker_config=postproc_worker_config,
85-
is_llm_executor=is_llm_executor)
86-
if not has_event_loop():
87-
self.init_workers_sync()
83+
self.worker_kwargs = dict(
84+
**worker_kwargs,
85+
postproc_worker_config=postproc_worker_config,
86+
is_llm_executor=is_llm_executor)
8887

8988
if self.use_rpc:
9089
self.init_rpc_executor()
91-
worker_kwargs['rpc_addr'] = self.rpc_addr
92-
self.create_workers(RayGPUWorker, worker_kwargs)
90+
self.worker_kwargs['rpc_addr'] = self.rpc_addr
91+
if not has_event_loop():
92+
self.init_workers_sync()
9393
self.setup_engine_remote()
9494
self.setup_mainloop(tasks=[self._fetch_responses_loop_async],
9595
thread_name="ray_executor_main_loop")
@@ -111,7 +111,8 @@ def __init__(self,
111111
self.response_sync_queue)
112112
self.response_queue.warmup.remote()
113113
self.response_sync_queue.warmup.remote()
114-
self.create_workers(RayGPUWorker, worker_kwargs)
114+
if not has_event_loop():
115+
self.init_workers_sync()
115116

116117
except Exception as e:
117118
self.shutdown()
@@ -149,25 +150,16 @@ def create_workers(self, worker_cls, worker_kwargs):
149150
def init_workers_sync(self):
150151
self.create_workers(RayGPUWorker, self.worker_kwargs)
151152
try:
152-
ray.get([worker.__ray_ready__.remote() for worker in self.workers])
153+
ray.get(self._get_worker_ready_futures())
153154
except ray.exceptions.ActorDiedError as e:
154-
if "The actor died because of an error raised in its creation task" in str(
155-
e):
156-
raise RuntimeError(
157-
"RayGPUWorker died during initialization") from e
158-
raise
155+
raise RuntimeError("RayGPUWorker died during initialization") from e
159156

160157
async def init_workers_async(self):
161158
self.create_workers(RayGPUWorker, self.worker_kwargs)
162159
try:
163-
await asyncio.gather(*[worker.__ray_ready__.remote() for worker in self.workers])
160+
await asyncio.gather(*self._get_worker_ready_futures())
164161
except ray.exceptions.ActorDiedError as e:
165-
if "The actor died because of an error raised in its creation task" in str(
166-
e):
167-
raise RuntimeError(
168-
"RayGPUWorker died during initialization") from e
169-
raise
170-
162+
raise RuntimeError("RayGPUWorker died during initialization") from e
171163

172164
@unwrap_ray_errors()
173165
def call_all_ray_workers(self, func: str, leader_only: bool,
@@ -334,6 +326,9 @@ def shutdown(self):
334326
logger.debug("Shutting down Ray cluster")
335327
ray.shutdown()
336328

329+
def _get_worker_ready_futures(self):
330+
return [worker.__ray_ready__.remote() for worker in self.workers]
331+
337332
def _get_placement_group(self,
338333
tp_size: int) -> Tuple[PlacementGroup, List[int]]:
339334
"""

0 commit comments

Comments (0)