Commit 865c24c

[TRTLLM-8922][chore] Remove hang detector from py_cache_transceiver_mp test
Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
1 parent 747915e commit 865c24c

File tree

11 files changed, +141 -98 lines changed

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 14 additions & 11 deletions
@@ -2728,12 +2728,14 @@ def _check_disagg_gen_transfer_status(self):
             req.is_disagg_generation_transmission_in_progress
             for req in self.active_requests
         ])
-        need_check_one = all([
+        non_gen_first_reqs = [
+            req for req in self.active_requests
+            if req.py_disaggregated_params and req.py_disaggregated_params.
+            schedule_style != DisaggScheduleStyle.GENERATION_FIRST
+        ]
+        need_check_one = bool(non_gen_first_reqs) and all(
             req.is_disagg_generation_transmission_in_progress
-            for req in self.active_requests
-            if req.py_disaggregated_params \
-            and req.py_disaggregated_params.schedule_style != DisaggScheduleStyle.GENERATION_FIRST
-        ])
+            for req in non_gen_first_reqs)
 
         if need_check:
             at_least_num = 1 if need_check_one else 0

@@ -2874,7 +2876,6 @@ def _prepare_disagg_gen_transmission_complete(self, scheduled_batch):
             req.decoding_iter = 1
             req.py_decoding_iter = 1
             req.py_kv_transfer_start_time = None
-            req.decoding_iter = 1
             first_gen_tokens = req.context_phase_params.first_gen_tokens
             ctx_draft_tokens = req.context_phase_params.draft_tokens
             req.py_draft_tokens = [] if ctx_draft_tokens is None else ctx_draft_tokens

@@ -2944,12 +2945,14 @@ def _recv_disagg_gen_cache(self, new_gen_reqs):
             if req.state == LlmRequestState.DISAGG_GENERATION_TRANS_IN_PROGRESS:
                 req.py_kv_transfer_start_time = time.time()
 
-        block_transfer = all([
+        non_gen_first_active = [
+            req for req in self.active_requests
+            if req.py_disaggregated_params and req.py_disaggregated_params.
+            schedule_style != DisaggScheduleStyle.GENERATION_FIRST
+        ]
+        block_transfer = bool(non_gen_first_active) and all(
             req.is_disagg_generation_transmission_in_progress
-            and req.py_disaggregated_params.schedule_style
-            != DisaggScheduleStyle.GENERATION_FIRST
-            for req in self.active_requests
-        ])
+            for req in non_gen_first_active)
         self._check_disagg_gen_cache_transfer_status(1 if block_transfer else 0)
 
         return
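
A note on the guard introduced above: Python's `all()` is vacuously `True` for an empty iterable, so the old inline filters reported every transfer as in progress even when no non-generation-first request existed. Hoisting the filter into a named list and prefixing `bool(...)` makes the empty case explicitly `False`. A minimal, self-contained sketch of the pitfall; `Req` is a hypothetical stand-in for `LlmRequest`, and only the guard logic mirrors the diff:

# Minimal sketch of the vacuous-truth pitfall fixed above.
from dataclasses import dataclass

@dataclass
class Req:
    in_progress: bool
    gen_first: bool  # stands in for schedule_style == GENERATION_FIRST

def need_check_one_old(reqs):
    # all() over an empty generator is vacuously True, so with no
    # eligible requests this wrongly demands a blocking check.
    return all(r.in_progress for r in reqs if not r.gen_first)

def need_check_one_new(reqs):
    eligible = [r for r in reqs if not r.gen_first]
    return bool(eligible) and all(r.in_progress for r in eligible)

reqs = [Req(in_progress=False, gen_first=True)]  # only GENERATION_FIRST active
assert need_check_one_old(reqs) is True   # the bug: vacuously true
assert need_check_one_new(reqs) is False  # empty case handled explicitly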

tensorrt_llm/commands/serve.py

Lines changed: 4 additions & 10 deletions
@@ -24,8 +24,8 @@
 from tensorrt_llm.executor.utils import LlmLauncherEnvs
 from tensorrt_llm.inputs.multimodal import MultimodalServerConfig
 from tensorrt_llm.llmapi import (BuildConfig, CapacitySchedulerPolicy,
-                                 DisaggScheduleStyle, DynamicBatchConfig,
-                                 KvCacheConfig, SchedulerConfig, VisualGen)
+                                 DynamicBatchConfig, KvCacheConfig,
+                                 SchedulerConfig, VisualGen)
 from tensorrt_llm.llmapi.disagg_utils import (DisaggClusterConfig,
                                               MetadataServerConfig, ServerRole,
                                               extract_disagg_cluster_config,

@@ -1011,7 +1011,8 @@ def serve_encoder(model: str, host: str, port: int, log_level: str,
               help="The logging level.")
 @click.option("-s",
               "--schedule_style",
-              type=None,
+              type=click.Choice(["context_first", "generation_first"],
+                                case_sensitive=False),
               default=None,
               help="The schedule style for the disaggregated server.")
 @click.option(

@@ -1041,13 +1042,6 @@ def disaggregated(
 
     disagg_cfg = parse_disagg_config_file(config_file)
     if schedule_style:
-        valid_styles = [
-            key.lower() for key in DisaggScheduleStyle.__members__.keys()
-        ]
-        if schedule_style not in valid_styles:
-            raise ValueError(
-                f"Invalid schedule style: {schedule_style}, options: {valid_styles}"
-            )
         disagg_cfg.schedule_style = schedule_style
     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
         try:
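
Side note on the `--schedule_style` change: `type=click.Choice(...)` pushes validation into the CLI layer, so click rejects invalid values before the command body runs and lists the allowed options in `--help`. That is what made the hand-rolled check against `DisaggScheduleStyle.__members__` (and its now-unused import) removable. A minimal sketch using only standard click behavior; the command itself is hypothetical:

# demo.py: minimal sketch of CLI-level validation via click.Choice.
import click

@click.command()
@click.option("-s",
              "--schedule_style",
              type=click.Choice(["context_first", "generation_first"],
                                case_sensitive=False),
              default=None,
              help="The schedule style for the disaggregated server.")
def main(schedule_style):
    # click has already rejected anything outside the two choices;
    # case_sensitive=False also accepts e.g. GENERATION_FIRST.
    click.echo(f"schedule_style={schedule_style}")

if __name__ == "__main__":
    main()

Running `python demo.py -s bogus` exits with a usage error naming the valid choices, which is what the deleted ValueError used to do by hand.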

tensorrt_llm/executor/proxy.py

Lines changed: 2 additions & 2 deletions
@@ -24,7 +24,7 @@
 from .request import CancellingRequest, GenerationRequest
 from .result import GenerationResult, IterationResult
 from .rpc import RPCClient
-from .rpc.rpc_common import get_unique_ipc_addr
+from .rpc.rpc_common import RPCError, get_unique_ipc_addr
 from .utils import (ErrorResponse, WorkerCommIpcAddrs, create_mpi_comm_session,
                     get_spawn_proxy_process_env, is_llm_response,
                     print_alive_threads)

@@ -396,7 +396,7 @@ def get_disaggregated_params(self) -> dict:
         try:
             params = self.rpc_client.get_disaggregated_params().remote()
             return params if isinstance(params, dict) else {}
-        except Exception as e:
+        except RPCError as e:
             logger.warning(f"Error fetching disaggregated params via RPC: {e}")
             return {}
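
The narrower `except RPCError` means only transport-level RPC failures are downgraded to a warning plus an empty dict; unexpected exceptions (a mistyped attribute, a type error) now propagate instead of being silently swallowed. A sketch of the pattern, with `RPCError` redefined locally as a stand-in for `tensorrt_llm.executor.rpc.rpc_common.RPCError`:

# Sketch: catch only the failure you can recover from.
class RPCError(Exception):
    pass

def fetch_params(call):
    try:
        params = call()
        return params if isinstance(params, dict) else {}
    except RPCError as e:
        # Transport-level failure: degrade gracefully to empty params.
        print(f"Error fetching disaggregated params via RPC: {e}")
        return {}
    # Any other exception (TypeError, AttributeError, ...) propagates,
    # so real bugs surface instead of masquerading as "no params".

def broken_transport():
    raise RPCError("connection lost")

assert fetch_params(lambda: {"ctx": 1}) == {"ctx": 1}
assert fetch_params(broken_transport) == {}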

tensorrt_llm/llmapi/disagg_utils.py

Lines changed: 1 addition & 1 deletion
@@ -178,7 +178,7 @@ def extract_disagg_cfg(hostname: str = 'localhost',
                            conditional_disagg_config, otlp_config,
                            max_retries, perf_metrics_max_requests,
                            disagg_cluster_config)
-    if node_id:
+    if node_id is not None:
         config.node_id = node_id
     if schedule_style:
         config.schedule_style = schedule_style
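
The `is not None` fix is about falsiness: `node_id == 0` is a legitimate value but evaluates falsy, so the old truthiness test silently skipped the assignment for node 0. A tiny illustration:

# Why `is not None` matters: a node_id of 0 is falsy but valid.
for node_id in (None, 0, 3):
    old = bool(node_id)        # `if node_id:` skips node 0
    new = node_id is not None  # `if node_id is not None:` keeps node 0
    print(f"node_id={node_id!r}: old={old}, new={new}")
# node_id=None: old=False, new=False
# node_id=0:    old=False, new=True   <- the case the fix addresses
# node_id=3:    old=True,  new=True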

tensorrt_llm/serve/openai_client.py

Lines changed: 3 additions & 1 deletion
@@ -127,7 +127,9 @@ async def _send_request(
         if server is None:
             server, _ = await self._router.get_next_server(request)
         url = f"http://{server}/{endpoint}"
-        logger.debug(f"Sending {self._role} request {request.disaggregated_params} to {url}")
+        logger.debug(
+            f"Sending {self._role} request {request.disaggregated_params.ctx_request_id} to {url}"
+        )
         try:
             self._metrics_collector.total_requests.inc()
             resp_generator = self._post_with_retry(server, url, request, hooks)

tensorrt_llm/serve/openai_disagg_service.py

Lines changed: 24 additions & 21 deletions
@@ -405,40 +405,43 @@ async def _consume_gen():
                     await queue.put(e)
                 await queue.put(None)  # sentinel
 
-            asyncio.create_task(_consume_gen())
+            consume_task: asyncio.Task = asyncio.create_task(_consume_gen())
 
             # Now send ctx request — gen server has received its request
             await self._ctx_client.send_request(ctx_req, server=ctx_server, hooks=hooks)
 
             async def _yield_from_queue():
-                while True:
-                    item = await queue.get()
-                    if item is None:
-                        break
-                    if isinstance(item, Exception):
-                        raise item
-                    yield item
+                try:
+                    while True:
+                        item = await queue.get()
+                        if item is None:
+                            break
+                        if isinstance(item, Exception):
+                            raise item
+                        yield item
+                finally:
+                    if not consume_task.done():
+                        consume_task.cancel()
+                        try:
+                            await consume_task
+                        except asyncio.CancelledError:
+                            pass
 
             return _yield_from_queue()
         else:
             # Non-streaming or no ctx needed: both HTTP POSTs fire eagerly
             # through generator consumption, so asyncio.gather works fine.
             tasks = []
             if need_ctx:
-                async def request_ctx():
-                    response = await self._ctx_client.send_request(
-                        ctx_req, server=ctx_server, hooks=hooks
+                tasks.append(
+                    asyncio.create_task(
+                        self._ctx_client.send_request(ctx_req, server=ctx_server, hooks=hooks)
                     )
-                    return response
-
-                tasks.append(asyncio.create_task(request_ctx()))
-
-            async def request_gen():
-                response = await self._gen_client.send_request(
-                    gen_req, server=gen_server, hooks=hooks
                 )
-                return response
-
-            tasks.append(asyncio.create_task(request_gen()))
+            tasks.append(
+                asyncio.create_task(
+                    self._gen_client.send_request(gen_req, server=gen_server, hooks=hooks)
+                )
+            )
             responses = await asyncio.gather(*tasks)
             return responses[-1]
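
The `try/finally` wrapped around `_yield_from_queue` ties the lifetime of the background `_consume_gen` task to the generator handed back to the caller: if the caller stops iterating early (client disconnect, downstream error), the `finally` block cancels and reaps the task instead of leaking it. A self-contained sketch of the same pattern, with hypothetical producer/consumer names:

# Sketch: cancel a background producer task when the consuming
# async generator is closed early. All names here are hypothetical.
import asyncio

async def main():
    queue: asyncio.Queue = asyncio.Queue()

    async def produce():
        for i in range(100):
            await queue.put(i)
            await asyncio.sleep(0.01)
        await queue.put(None)  # sentinel

    produce_task = asyncio.create_task(produce())

    async def consume():
        try:
            while True:
                item = await queue.get()
                if item is None:
                    break
                yield item
        finally:
            # Runs on break, exception, or early aclose(): reap the producer.
            if not produce_task.done():
                produce_task.cancel()
                try:
                    await produce_task
                except asyncio.CancelledError:
                    pass

    gen = consume()
    async for item in gen:
        if item == 2:
            break  # early exit, as a disconnecting client would cause
    await gen.aclose()  # triggers the finally block above
    assert produce_task.cancelled()

asyncio.run(main())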

tests/integration/defs/accuracy/test_disaggregated_serving.py

Lines changed: 7 additions & 6 deletions
@@ -1616,12 +1616,13 @@ def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config,
     @pytest.mark.parametrize(
         "gen_tp_pp", [(1, 1), (1, 2), (2, 1), (2, 2)],
         ids=["gen_tp1pp1", "gen_tp1pp2", "gen_tp2pp1", "gen_tp2pp2"])
-    @pytest.mark.parametrize("ctx_tp_pp", [(1, 1), (1, 2), (2, 1), (2, 2),
-                                           (1, 4)],
-                             ids=[
-                                 "ctx_tp1pp1", "ctx_tp1pp2", "ctx_tp2pp1",
-                                 "ctx_tp2pp2", "ctx_tp1pp4"
-                             ])
+    @pytest.mark.parametrize(
+        "ctx_tp_pp",
+        [(1, 1), (1, 2), (2, 1), (2, 2), (1, 4)],
+        ids=[
+            "ctx_tp1pp1", "ctx_tp1pp2", "ctx_tp2pp1", "ctx_tp2pp2", "ctx_tp1pp4"
+        ],
+    )
     def test_gen_first(self, ctx_tp_pp, gen_tp_pp):
         ctx_tp, ctx_pp = ctx_tp_pp
         gen_tp, gen_pp = gen_tp_pp
Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
+hostname: localhost
+port: 8000
+backend: "pytorch"
+cuda_graph_config: null
+free_gpu_memory_fraction: 0.2
+context_servers:
+  num_instances: 1
+  max_batch_size: 8
+  max_num_tokens: 3000
+  max_seq_len: 4096
+  tensor_parallel_size: 1
+  pipeline_parallel_size: 1
+  kv_cache_config:
+    enable_block_reuse: False
+    free_gpu_memory_fraction: 0.2
+    enable_partial_reuse: False
+
+  cache_transceiver_config:
+    backend: DEFAULT
+    transceiver_runtime: PYTHON
+  urls:
+    - "localhost:8001"
+generation_servers:
+  num_instances: 1
+  tensor_parallel_size: 1
+  pipeline_parallel_size: 1
+  max_batch_size: 256
+  max_num_tokens: 4096
+  max_seq_len: 4096
+  kv_cache_config:
+    enable_block_reuse: False
+    free_gpu_memory_fraction: 0.2
+    enable_partial_reuse: False
+  cache_transceiver_config:
+    backend: DEFAULT
+    transceiver_runtime: PYTHON
+  urls:
+    - "localhost:8002"
Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
+hostname: localhost
+port: 8000
+backend: "pytorch"
+cuda_graph_config: null
+free_gpu_memory_fraction: 0.2
+context_servers:
+  num_instances: 1
+  max_batch_size: 8
+  max_num_tokens: 3000
+  max_seq_len: 4096
+  tensor_parallel_size: 1
+  pipeline_parallel_size: 4
+  kv_cache_config:
+    enable_block_reuse: False
+    free_gpu_memory_fraction: 0.2
+    enable_partial_reuse: False
+
+  cache_transceiver_config:
+    backend: DEFAULT
+    transceiver_runtime: PYTHON
+  urls:
+    - "localhost:8001"
+generation_servers:
+  num_instances: 1
+  tensor_parallel_size: 1
+  pipeline_parallel_size: 1
+  max_batch_size: 256
+  max_num_tokens: 4096
+  max_seq_len: 4096
+  kv_cache_config:
+    enable_block_reuse: False
+    free_gpu_memory_fraction: 0.2
+    enable_partial_reuse: False
+  cache_transceiver_config:
+    backend: DEFAULT
+    transceiver_runtime: PYTHON
+  urls:
+    - "localhost:8002"

tests/integration/defs/disaggregated/test_disaggregated.py

Lines changed: 1 addition & 1 deletion
@@ -284,7 +284,7 @@ def get_client_test_set(test_desc):
                             verify_streaming_completion=True,
                             verify_chat=False,
                             verify_streaming_chat=False)
-    if test_desc in ("overlap", "trtllm_sampler"):
+    if test_desc.startswith("overlap") or test_desc == "trtllm_sampler":
        return ClientTestSet(completion=True,
                             completion_streaming=True,
                             chat=True,
