Skip to content

Commit 40ddad8

Browse files
committed
[TRTLLM-8922][fix] Address review comments on gen-first disagg PR
- Fix all([]) returning True when no non-gen-first requests exist
- Catch RPCError instead of broad Exception in proxy
- Log only ctx_request_id instead of full disaggregated_params
- Cancel background consume task when stream generator stops early
- Fix parametrize indentation lint in test
- Include overlap_gen_first* in overlap client test validation

Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
1 parent c956aa1 commit 40ddad8

File tree

6 files changed

+43
-28
lines changed

6 files changed

+43
-28
lines changed

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 14 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -2728,12 +2728,14 @@ def _check_disagg_gen_transfer_status(self):
27282728
req.is_disagg_generation_transmission_in_progress
27292729
for req in self.active_requests
27302730
])
2731-
need_check_one = all([
2731+
non_gen_first_reqs = [
2732+
req for req in self.active_requests
2733+
if req.py_disaggregated_params and req.py_disaggregated_params.
2734+
schedule_style != DisaggScheduleStyle.GENERATION_FIRST
2735+
]
2736+
need_check_one = bool(non_gen_first_reqs) and all(
27322737
req.is_disagg_generation_transmission_in_progress
2733-
for req in self.active_requests
2734-
if req.py_disaggregated_params \
2735-
and req.py_disaggregated_params.schedule_style != DisaggScheduleStyle.GENERATION_FIRST
2736-
])
2738+
for req in non_gen_first_reqs)
27372739

27382740
if need_check:
27392741
at_least_num = 1 if need_check_one else 0
@@ -2943,12 +2945,14 @@ def _recv_disagg_gen_cache(self, new_gen_reqs):
29432945
if req.state == LlmRequestState.DISAGG_GENERATION_TRANS_IN_PROGRESS:
29442946
req.py_kv_transfer_start_time = time.time()
29452947

2946-
block_transfer = all([
2948+
non_gen_first_active = [
2949+
req for req in self.active_requests
2950+
if req.py_disaggregated_params and req.py_disaggregated_params.
2951+
schedule_style != DisaggScheduleStyle.GENERATION_FIRST
2952+
]
2953+
block_transfer = bool(non_gen_first_active) and all(
29472954
req.is_disagg_generation_transmission_in_progress
2948-
and req.py_disaggregated_params.schedule_style
2949-
!= DisaggScheduleStyle.GENERATION_FIRST
2950-
for req in self.active_requests
2951-
])
2955+
for req in non_gen_first_active)
29522956
self._check_disagg_gen_cache_transfer_status(1 if block_transfer else 0)
29532957

29542958
return

tensorrt_llm/executor/proxy.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@
2424
from .request import CancellingRequest, GenerationRequest
2525
from .result import GenerationResult, IterationResult
2626
from .rpc import RPCClient
27-
from .rpc.rpc_common import get_unique_ipc_addr
27+
from .rpc.rpc_common import RPCError, get_unique_ipc_addr
2828
from .utils import (ErrorResponse, WorkerCommIpcAddrs, create_mpi_comm_session,
2929
get_spawn_proxy_process_env, is_llm_response,
3030
print_alive_threads)
@@ -396,7 +396,7 @@ def get_disaggregated_params(self) -> dict:
396396
try:
397397
params = self.rpc_client.get_disaggregated_params().remote()
398398
return params if isinstance(params, dict) else {}
399-
except Exception as e:
399+
except RPCError as e:
400400
logger.warning(f"Error fetching disaggregated params via RPC: {e}")
401401
return {}
402402

tensorrt_llm/serve/openai_client.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -127,7 +127,9 @@ async def _send_request(
127127
if server is None:
128128
server, _ = await self._router.get_next_server(request)
129129
url = f"http://{server}/{endpoint}"
130-
logger.debug(f"Sending {self._role} request {request.disaggregated_params} to {url}")
130+
logger.debug(
131+
f"Sending {self._role} request {request.disaggregated_params.ctx_request_id} to {url}"
132+
)
131133
try:
132134
self._metrics_collector.total_requests.inc()
133135
resp_generator = self._post_with_retry(server, url, request, hooks)

tensorrt_llm/serve/openai_disagg_service.py

Lines changed: 16 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -405,19 +405,27 @@ async def _consume_gen():
405405
await queue.put(e)
406406
await queue.put(None) # sentinel
407407

408-
consume_task = asyncio.create_task(_consume_gen()) # noqa: F841 prevent GC
408+
consume_task: asyncio.Task = asyncio.create_task(_consume_gen())
409409

410410
# Now send ctx request — gen server has received its request
411411
await self._ctx_client.send_request(ctx_req, server=ctx_server, hooks=hooks)
412412

413413
async def _yield_from_queue():
414-
while True:
415-
item = await queue.get()
416-
if item is None:
417-
break
418-
if isinstance(item, Exception):
419-
raise item
420-
yield item
414+
try:
415+
while True:
416+
item = await queue.get()
417+
if item is None:
418+
break
419+
if isinstance(item, Exception):
420+
raise item
421+
yield item
422+
finally:
423+
if not consume_task.done():
424+
consume_task.cancel()
425+
try:
426+
await consume_task
427+
except asyncio.CancelledError:
428+
pass
421429

422430
return _yield_from_queue()
423431
else:

tests/integration/defs/accuracy/test_disaggregated_serving.py

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1616,12 +1616,13 @@ def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config,
16161616
@pytest.mark.parametrize(
16171617
"gen_tp_pp", [(1, 1), (1, 2), (2, 1), (2, 2)],
16181618
ids=["gen_tp1pp1", "gen_tp1pp2", "gen_tp2pp1", "gen_tp2pp2"])
1619-
@pytest.mark.parametrize("ctx_tp_pp", [(1, 1), (1, 2), (2, 1), (2, 2),
1620-
(1, 4)],
1621-
ids=[
1622-
"ctx_tp1pp1", "ctx_tp1pp2", "ctx_tp2pp1",
1623-
"ctx_tp2pp2", "ctx_tp1pp4"
1624-
])
1619+
@pytest.mark.parametrize(
1620+
"ctx_tp_pp",
1621+
[(1, 1), (1, 2), (2, 1), (2, 2), (1, 4)],
1622+
ids=[
1623+
"ctx_tp1pp1", "ctx_tp1pp2", "ctx_tp2pp1", "ctx_tp2pp2", "ctx_tp1pp4"
1624+
],
1625+
)
16251626
def test_gen_first(self, ctx_tp_pp, gen_tp_pp):
16261627
ctx_tp, ctx_pp = ctx_tp_pp
16271628
gen_tp, gen_pp = gen_tp_pp

tests/integration/defs/disaggregated/test_disaggregated.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -284,7 +284,7 @@ def get_client_test_set(test_desc):
284284
verify_streaming_completion=True,
285285
verify_chat=False,
286286
verify_streaming_chat=False)
287-
if test_desc in ("overlap", "trtllm_sampler"):
287+
if test_desc.startswith("overlap") or test_desc == "trtllm_sampler":
288288
return ClientTestSet(completion=True,
289289
completion_streaming=True,
290290
chat=True,

0 commit comments

Comments
 (0)