
Commit 9d6634b

shuyixiong authored and dominicshanshan committed
[TRTLLM-8507][fix] Fix ray resource cleanup and error handling in LoRA test (NVIDIA#8175)
Signed-off-by: shuyix <[email protected]>
1 parent 8000827 commit 9d6634b

File tree

6 files changed: +110 -57 lines changed

- tensorrt_llm/_ray_utils.py
- tensorrt_llm/executor/ray_executor.py
- tensorrt_llm/executor/result.py
- tensorrt_llm/ray_stub.py
- tests/unittest/llmapi/test_llm_pytorch.py
- tests/unittest/utils/util.py

tensorrt_llm/_ray_utils.py

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from contextlib import contextmanager
+
+try:
+    import ray
+except ImportError:
+    import tensorrt_llm.ray_stub as ray
+
+
+@contextmanager
+def unwrap_ray_errors():
+    try:
+        yield
+    except ray.exceptions.RayTaskError as e:
+        raise e.as_instanceof_cause() from e
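
For reference, a minimal usage sketch of the new helper (not part of this commit; it assumes Ray is installed, and failing_task is a hypothetical demo task): an exception raised inside a Ray task normally surfaces as ray.exceptions.RayTaskError, and unwrap_ray_errors() re-raises it as the original exception type. It works both as a context manager and, because contextmanager-based helpers are also decorators, as @unwrap_ray_errors(), which is how the executor and result code below uses it.

import ray

from tensorrt_llm._ray_utils import unwrap_ray_errors


@ray.remote
def failing_task():
    # Hypothetical task used only to demonstrate the error unwrapping.
    raise ValueError("boom")


ray.init()
try:
    with unwrap_ray_errors():
        ray.get(failing_task.remote())
except ValueError as e:
    # The RayTaskError raised by ray.get() has been converted back to the
    # task's original ValueError via as_instanceof_cause().
    print(f"caught original error: {e}")
ray.shutdown()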

tensorrt_llm/executor/ray_executor.py

Lines changed: 62 additions & 45 deletions
@@ -12,6 +12,7 @@
                                       get_current_placement_group,
                                       placement_group)
 
+from tensorrt_llm._ray_utils import unwrap_ray_errors
 from tensorrt_llm._utils import get_free_port
 from tensorrt_llm.logger import logger
 
@@ -57,48 +58,54 @@ def __init__(self,
             "runtime_env": runtime_env
         }
 
-        if os.environ.get("TLLM_RAY_FORCE_LOCAL_CLUSTER", "0") != "1":
-            try:
-                ray.init(address="auto", **ray_init_args)
-                logger.info(f"Attached to an existing Ray cluster.")
-            except ConnectionError:
-                logger.info(f"Ray cluster not found, starting a new one.")
-
-            if not ray.is_initialized():
-                ray.init(**ray_init_args)
+        try:
+            if os.environ.get("TLLM_RAY_FORCE_LOCAL_CLUSTER", "0") != "1":
+                try:
+                    ray.init(address="auto", **ray_init_args)
+                    logger.info(f"Attached to an existing Ray cluster.")
+                except ConnectionError:
+                    logger.info(f"Ray cluster not found, starting a new one.")
+
+                if not ray.is_initialized():
+                    ray.init(**ray_init_args)
+                    self.has_start_local_cluser = True
+            else:
+                ray.init(address="local", **ray_init_args)
                 self.has_start_local_cluser = True
-        else:
-            ray.init(address="local", **ray_init_args)
-            self.has_start_local_cluser = True
 
-        self.world_size = model_world_size
-        self.tp_size = tp_size
-        self.master_address = ray.util.get_node_ip_address()
-        self.master_port = get_free_port()
-
-        self.response_queue = RayAsyncQueue.options(runtime_env={
-            "env_vars": {
-                "TLLM_DISABLE_MPI": "1"
-            }
-        }).remote()
-        self.response_sync_queue = RaySyncQueue.options(runtime_env={
-            "env_vars": {
-                "TLLM_DISABLE_MPI": "1"
-            }
-        }).remote()
-        self.async_response_queue_weakref = self.create_actor_weak_ref(
-            self.response_queue)
-        self.sync_response_queue_weakref = self.create_actor_weak_ref(
-            self.response_sync_queue)
-        self.response_queue.warmup.remote()
-        self.response_sync_queue.warmup.remote()
-
-        worker_kwargs = dict(**worker_kwargs,
-                             postproc_worker_config=postproc_worker_config,
-                             is_llm_executor=is_llm_executor,
-                             kv_connector_config=kv_connector_config)
-
-        self.create_workers(RayGPUWorker, worker_kwargs)
+            self.world_size = model_world_size
+            self.tp_size = tp_size
+            self.master_address = ray.util.get_node_ip_address()
+            self.master_port = get_free_port()
+
+            self.response_queue = RayAsyncQueue.options(runtime_env={
+                "env_vars": {
+                    "TLLM_DISABLE_MPI": "1"
+                }
+            }).remote()
+            self.response_sync_queue = RaySyncQueue.options(runtime_env={
+                "env_vars": {
+                    "TLLM_DISABLE_MPI": "1"
+                }
+            }).remote()
+            self.async_response_queue_weakref = self.create_actor_weak_ref(
+                self.response_queue)
+            self.sync_response_queue_weakref = self.create_actor_weak_ref(
+                self.response_sync_queue)
+            self.response_queue.warmup.remote()
+            self.response_sync_queue.warmup.remote()
+
+            worker_kwargs = dict(**worker_kwargs,
+                                 postproc_worker_config=postproc_worker_config,
+                                 is_llm_executor=is_llm_executor,
+                                 kv_connector_config=kv_connector_config)
+
+            self.create_workers(RayGPUWorker, worker_kwargs)
+        except Exception as e:
+            # Clean up the Ray resources early during exception
+            self.shutdown()
+            logger.error(f"Failed to initialize RayExecutor: {e}")
+            raise e
 
     @staticmethod
     def create_actor_weak_ref(actor_handle: ray.actor.ActorHandle):
@@ -137,12 +144,19 @@ def create_workers(self, worker_cls, worker_kwargs):
             for rank in range(self.world_size)
         ]
 
-        ray.get([worker.__ray_ready__.remote() for worker in self.workers])
+        try:
+            ray.get([worker.__ray_ready__.remote() for worker in self.workers])
+        except ray.exceptions.ActorDiedError as e:
+            if "The actor died because of an error raised in its creation task" in str(
+                    e):
+                raise RuntimeError(
+                    "RayGPUWorker died during initialization") from e
+            raise
 
+    @unwrap_ray_errors()
     def call_all_ray_workers(self, func: str, leader_only: bool,
                              async_call: bool, *args, **kwargs):
         workers = (self.workers[0], ) if leader_only else self.workers
-
         if async_call:
             return [
                 getattr(worker, func).remote(*args, **kwargs)
@@ -154,6 +168,7 @@ def call_all_ray_workers(self, func: str, leader_only: bool,
                 for worker in workers
             ])
 
+    @unwrap_ray_errors()
     def collective_rpc(self,
                        method: str,
                        args: tuple = (),
@@ -174,7 +189,6 @@ def collective_rpc(self,
             # Ray actor doesn't work with __getattr__ delegation.
             refs.append(w.call_worker_method.remote(method, *args,
                                                     **kwargs))
-
         return refs if non_block else ray.get(refs)
 
     def submit(self, request: GenerationRequest) -> GenerationResult:
@@ -224,11 +238,14 @@ def shutdown(self):
         self.workers = None
         if hasattr(self,
                    "placement_group") and self.placement_group is not None:
-            ray.util.remove_placement_group(self.placement_group)
+            # Only remove placement group if Ray is still initialized
+            # to avoid triggering auto_init_ray() during program exit
+            if ray.is_initialized():
+                ray.util.remove_placement_group(self.placement_group)
             self.placement_group = None
         self.bundle_indices = None
 
-        if self.has_start_local_cluser:
+        if self.has_start_local_cluser and ray.is_initialized():
             logger.debug("Shutting down Ray cluster")
             ray.shutdown()
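
Taken together, the executor changes follow a single pattern: run the whole constructor inside a try, call shutdown() before re-raising if any step fails, and guard every Ray teardown call with ray.is_initialized() so shutdown stays safe at interpreter exit. A stripped-down sketch of that pattern (ExampleExecutor and its attributes are hypothetical, not the real RayExecutor):

import ray


class ExampleExecutor:
    """Hypothetical class illustrating the cleanup-on-failure pattern above."""

    def __init__(self):
        self.started_local_cluster = False
        try:
            if not ray.is_initialized():
                ray.init()
                self.started_local_cluster = True
            # ... create placement groups, actors, queues, workers ...
        except Exception:
            # Release whatever was acquired before the failure, then re-raise
            # so the caller still sees the original error.
            self.shutdown()
            raise

    def shutdown(self):
        # Only touch Ray if it is still initialized: after a failed init or
        # during program exit, calling into Ray could re-initialize it or fail.
        if self.started_local_cluster and ray.is_initialized():
            ray.shutdown()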

tensorrt_llm/executor/result.py

Lines changed: 5 additions & 3 deletions
@@ -16,6 +16,7 @@
 except ModuleNotFoundError:
     from tensorrt_llm import ray_stub as ray
 
+from .._ray_utils import unwrap_ray_errors
 from .._utils import mpi_disabled, nvtx_range_debug
 from ..bindings import executor as tllm
 from ..disaggregated_params import DisaggregatedParams
@@ -274,8 +275,8 @@ def __init__(self,
             else:
                 self.queue = ray_queue
                 self.aqueue = None
-
-                ray.get(self.queue.register.remote(id))
+                with unwrap_ray_errors():
+                    ray.get(self.queue.register.remote(id))
         else:
             if has_event_loop():
                 self.aqueue = AsyncQueue()
@@ -735,7 +736,8 @@ def _handle_ray_response(self, response: Any):
 
     def _result_step(self, timeout: Optional[float] = None):
         if mpi_disabled():
-            response = ray.get(self.queue.get.remote(self.request_id))
+            with unwrap_ray_errors():
+                response = ray.get(self.queue.get.remote(self.request_id))
             response = self._handle_ray_response(response)
         else:
             response = self.queue.get()

tensorrt_llm/ray_stub.py

Lines changed: 12 additions & 5 deletions
@@ -12,11 +12,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import functools
+from functools import wraps as _wraps
 
-from tensorrt_llm._utils import mpi_disabled
+from tensorrt_llm._utils import mpi_disabled as _mpi_disabled
 
-if mpi_disabled():
+if _mpi_disabled():
     raise RuntimeError(
         "Ray requested (TLLM_DISABLE_MPI=1), but not installed. Please install Ray."
     )
@@ -27,14 +27,21 @@ def remote(*args, **kwargs):
     def decorator(func):
         # Returns a function that always raises.
         # Decorated class depends on ray, but ray is not installed.
-        @functools.wraps(func)
+        @_wraps(func)
        def stub_checker(*_, **__):
            raise RuntimeError(
-                "Ray not installed, cannot use Ray based feature.")
+                f'Ray not installed, so the remote function / actor "{func.__name__}" is not available.'
+            )
 
         return stub_checker
 
     if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
         return decorator(args[0])
 
     return decorator
+
+
+def __getattr__(name):
+    raise RuntimeError(
+        f'Ray not installed, so "ray.{name}" is unavailable. Please install Ray.'
+    )
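
With these changes, an environment without Ray fails loudly and descriptively: the new module-level __getattr__ covers any ray.<name> access, and the remote stub names the decorated function. A behavior sketch (assumes Ray is absent and TLLM_DISABLE_MPI is not set, since the stub refuses to import otherwise; importing the stub directly here is only for illustration, in the codebase it is the fallback for "import ray"):

import tensorrt_llm.ray_stub as ray

# Any attribute the stub does not define goes through the module-level
# __getattr__ and raises immediately, naming the missing symbol.
try:
    ray.get  # the attribute lookup itself raises
except RuntimeError as e:
    print(e)  # Ray not installed, so "ray.get" is unavailable. Please install Ray.


@ray.remote
def my_task():
    # Hypothetical function; decorating it works, calling it raises.
    pass


try:
    my_task()
except RuntimeError as e:
    print(e)  # names "my_task" in the message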

tests/unittest/llmapi/test_llm_pytorch.py

Lines changed: 0 additions & 1 deletion
@@ -414,7 +414,6 @@ def test_llama_7b_multi_lora_evict_and_reload_evicted_adapters_in_cpu_and_gpu_ca
                          repeats_per_call=1)
 
 
-@skip_ray
 @skip_gpu_memory_less_than_40gb
 def test_llama_7b_peft_cache_config_affects_peft_cache_size():
     """Tests that LLM arg of peft_cache_config affects the peft cache sizes.

tests/unittest/utils/util.py

Lines changed: 3 additions & 3 deletions
@@ -19,7 +19,8 @@
 from parameterized import parameterized
 
 import tensorrt_llm
-from tensorrt_llm._utils import torch_dtype_to_trt, trt_dtype_to_torch
+from tensorrt_llm._utils import (mpi_disabled, torch_dtype_to_trt,
+                                 trt_dtype_to_torch)
 from tensorrt_llm.llmapi.utils import get_total_gpu_memory
 from tensorrt_llm.plugin.plugin import ContextFMHAType
 from tensorrt_llm.quantization import QuantMode
@@ -449,5 +450,4 @@ def check_accuracy(a, b, atol, rtol, percent):
 
 
 skip_ray = pytest.mark.skipif(
-    os.environ.get("TLLM_DISABLE_MPI") == "1",
-    reason="This test is skipped for Ray orchestrator.")
+    mpi_disabled(), reason="This test is skipped for Ray orchestrator.")
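
A usage sketch of the updated marker (not part of this commit; the import path is assumed from the test layout above): tests opt out of the Ray orchestrator the same way as before, the condition is just driven by mpi_disabled() now.

from utils.util import skip_ray  # path assumed from tests/unittest/utils/util.py


@skip_ray
def test_requires_mpi_orchestrator():
    # Skipped whenever mpi_disabled() is True, i.e. when the Ray orchestrator is in use.
    assert True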
