Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions tensorrt_llm/_torch/disaggregation/transceiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def _broadcast_context_endpoint(self) -> str:
def _init_sync_policy(self):
m = self._mapping
self._ctx_need_tp_sync = m.tp_size > 1 and not m.enable_attention_dp
self._ctx_need_pp_sync = m.pp_size > 1
self._gen_need_sync = not (m.world_size == 1 or (m.enable_attention_dp and m.pp_size == 1))
pp_allgather: Callable = getattr(self._dist, "pp_allgather")
self._gen_allgather: Callable = (
Expand Down Expand Up @@ -183,9 +184,25 @@ def _need_aux_transfer(req: LlmRequest) -> bool:
return params is not None and params.schedule_style == DisaggScheduleStyle.GENERATION_FIRST

def _ctx_consensus(self, local_ids: list) -> list:
    """Promote only the request ids whose peer info is present on every rank.

    Runs a TP allgather (when TP sync is enabled) followed by a PP
    allgather (when PP sync is enabled) and keeps only the request ids
    reported by all participating ranks, so every rank promotes the same
    set of requests. Without this consensus, peer info arriving at
    different times on different ranks would cause scheduling mismatches.
    """
    # TP consensus: ensure all TP ranks have peer info
    sync_size = self._dist.tp_size if self._ctx_need_tp_sync else 1
    all_ranks = self._dist.tp_allgather(local_ids) if self._ctx_need_tp_sync else [local_ids]
    ready_ids = _find_consensus_request_ids(all_ranks, sync_size)

    # PP consensus: ensure all PP ranks have peer info before promoting.
    # In PP, the first PP rank schedules and propagates to others. If a
    # request is promoted on the first rank but peer info hasn't arrived
    # on other ranks, respond_and_send_async on those ranks would fail
    # to dispatch the KV transfer (gen-first skips listener dispatch).
    # TODO: This is a workaround for functionality: pp_allgather impacts
    # the pp loop performance. One possible solution is to let pp rank0
    # decide the ready request ids, the other pp ranks treat the unready
    # request as ctx-first requests.
    if self._ctx_need_pp_sync:
        pp_all_ranks = getattr(self._dist, "pp_allgather")(ready_ids)
        ready_ids = _find_consensus_request_ids(pp_all_ranks, self._mapping.pp_size)

    return ready_ids

def _gen_consensus(self, local_ids: list) -> list:
sync_size = (
Expand Down Expand Up @@ -365,22 +382,20 @@ def get_disaggregated_params(self) -> Dict[str, Any]:
# requests before context-phase response data arrives.
return {
"ctx_dp_rank": self._dp_rank,
"ctx_info_endpoint": [self._context_info_endpoint]
if self._context_info_endpoint
else None,
"ctx_info_endpoint": self._context_info_endpoint or None,
}

def prepare_context_requests(self, requests: List[LlmRequest]):
# Place new generation-first context requests into wait state, then
# use tp_allgather consensus to promote ready requests to CONTEXT_INIT.
# use allgather consensus to promote ready requests to CONTEXT_INIT.
for req in requests:
rid = get_unique_rid(req)
if rid not in self._send_sessions:
self._wait_reqs[rid] = req
req.state = LlmRequestState.DISAGG_CONTEXT_WAIT_SCHEDULER

# Check which waiting requests have peer info locally, then tp_allgather
# consensus so all TP ranks agree before promoting.
# Check which waiting requests have peer info locally, then allgather
# consensus so all TP/PP ranks agree before promoting.
# Without consensus, background peer info arriving at different times on
# different ranks causes scheduling mismatches → hang.
local_ready = [
Expand Down
3 changes: 2 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,8 @@ def cancel_request(self, req: LlmRequest):
return self.impl.cancel_request(req)

def prepare_context_requests(self, requests: List[LlmRequest]):
    """No-op placeholder so callers may invoke this unconditionally.

    The C++ KV cache transceiver has nothing to prepare for incoming
    context requests; this intentionally does nothing (it must not raise,
    since the executor calls it on every transceiver implementation).
    """

def get_disaggregated_params(self):
# Cpp kv cache transceiver will set the disaggregated params to context response
Expand Down
97 changes: 78 additions & 19 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,19 @@ def _maybe_init_kv_connector_manager(self):
self.kv_connector_manager.layer_post_hook)

def _end_transfer_and_maybe_terminate(self, request: LlmRequest):
    """Finish the KV cache transfer for *request* and terminate it if done.

    Two paths: when a transceiver is present and the request is still in
    ``active_requests`` (the "fast transfer" case), a response is created
    and enqueued here before termination; otherwise the transfer is simply
    ended and the request terminated.
    """
    if self.kv_cache_transceiver and request in self.active_requests:
        # Fast-transfer: KV transfer completed in the same iteration
        # before _handle_responses could run. Create the response now
        # while state is still TRANS_IN_PROGRESS (required by C++
        # createResult). Then proceed with end_transfer + termination.
        response = request.create_response(False, self.dist.rank)
        if response:
            # Propagate cached-token count so the client sees cache-hit stats.
            response.result.cached_tokens = request.cached_tokens
            self._enqueue_responses([(request.py_request_id, response)])
        if self.async_transfer_manager.end_transfer(request):
            # Remove from active set before terminating so later loop
            # iterations do not process the terminated request.
            self.active_requests.remove(request)
            self._terminate_request(request)
        return
    if self.async_transfer_manager.end_transfer(request):
        self._terminate_request(request)

Expand Down Expand Up @@ -1276,6 +1289,7 @@ def _executor_loop_pp(self):
self._handle_control_request()

if self.kv_cache_transceiver:
self._check_disagg_ctx_schedulable_status(new_requests)
self._check_disagg_gen_transfer_status()

if self.enable_iter_perf_stats:
Expand All @@ -1300,11 +1314,22 @@ def _executor_loop_pp(self):
self._prepare_disagg_gen_init(
fitting_disagg_gen_init_requests)

all_gen_first = self.active_requests and all(
req.py_disaggregated_params
and req.py_disaggregated_params.schedule_style ==
DisaggScheduleStyle.GENERATION_FIRST
for req in self.active_requests)
if num_fitting_reqs == 0 and not fitting_disagg_gen_init_requests:
logger.warning(
"num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache"
)
self._check_disagg_ctx_cache_transfer_status(1)
if not all_gen_first:
logger.warning(
"num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache"
)
self._check_disagg_ctx_cache_transfer_status(1)
elif self.async_transfer_manager.has_any_inflight_requests(
):
# Non-blocking cleanup of completed/timed-out
# transfers to free KV blocks (see _executor_loop).
self._check_disagg_ctx_cache_transfer_status(0)

self.num_scheduled_requests = scheduled_batch.batch_size

Expand Down Expand Up @@ -1600,13 +1625,19 @@ def _handle_executed_batch(self, executed_batch: Optional[BatchStatePP]):
self._handle_canceled_requests()

finished_requests = self._handle_responses()
# Complete ctx send sessions AFTER responses are created so
# _handle_responses sees the request before it is terminated.
if self.kv_cache_transceiver:
self._check_disagg_ctx_cache_transfer_status(0)
sample_state_scheduled_requests = executed_batch.scheduled_requests
attn_metadata = getattr(self.model_engine, 'attn_metadata',
None)
kv_cache_dtype_byte_size = getattr(self.model_engine,
'kv_cache_dtype_byte_size',
None)
self.resource_manager.update_resources(
scheduled_requests, attn_metadata, kv_cache_dtype_byte_size)
sample_state_scheduled_requests, attn_metadata,
kv_cache_dtype_byte_size)

self._remove_inflight_ids(scheduled_requests)

Expand Down Expand Up @@ -1791,11 +1822,23 @@ def _prepare_and_schedule_batch(self):
# For requests that are fitting disagg gen init, also prepare resources for KV cache manager
self._prepare_disagg_gen_init(fitting_disagg_gen_init_requests)

all_gen_first = self.active_requests and all(
req.py_disaggregated_params and req.py_disaggregated_params.
schedule_style == DisaggScheduleStyle.GENERATION_FIRST
for req in self.active_requests)
if num_fitting_reqs == 0 and not fitting_disagg_gen_init_requests:
logger.warning(
"num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache"
)
self._check_disagg_ctx_cache_transfer_status(1)
if not all_gen_first:
logger.warning(
"num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache"
)
self._check_disagg_ctx_cache_transfer_status(1)
elif self.async_transfer_manager.has_any_inflight_requests():
# Non-blocking cleanup of completed/timed-out transfers
# to free KV blocks. We avoid the blocking check because
# gen-first requests may be waiting for peer info (which
# would block indefinitely), but completed transfers must
# still be reaped so that KV cache can be reclaimed.
self._check_disagg_ctx_cache_transfer_status(0)

# In gen-only benchmark mode, all requests must fit in KV cache
# simultaneously. If some requests are stuck in INIT state and the
Expand Down Expand Up @@ -1880,6 +1923,7 @@ def _executor_loop(self):
finished_requests = []

can_queue, _ = self._can_queue(scheduled_batch)

if can_queue:
if self.kv_cache_transceiver:
# For generation requests which have completed KV cache transfer
Expand Down Expand Up @@ -1977,6 +2021,10 @@ def _executor_loop(self):

self._handle_canceled_requests()
finished_requests = self._handle_responses()
# Complete ctx send sessions AFTER responses are created so
# _handle_responses sees the request before it is terminated.
if self.kv_cache_transceiver:
self._check_disagg_ctx_cache_transfer_status(0)
# Compute GPU times after _handle_responses creates metric entries
# (safe in non-overlap mode: no next iteration to overwrite events)
self.perf_manager.compute_batch_gpu_times(
Expand Down Expand Up @@ -2135,6 +2183,7 @@ def _executor_loop_overlap(self):

can_queue, can_queue_this_rank = self._can_queue(
scheduled_batch)

if can_queue:
if self.kv_cache_transceiver:
# For generation requests which have completed KV cache transfer
Expand Down Expand Up @@ -2771,10 +2820,14 @@ def _check_disagg_gen_transfer_status(self):
req.is_disagg_generation_transmission_in_progress
for req in self.active_requests
])
need_check_one = all([
non_gen_first_reqs = [
req for req in self.active_requests
if req.py_disaggregated_params and req.py_disaggregated_params.
schedule_style != DisaggScheduleStyle.GENERATION_FIRST
]
need_check_one = bool(non_gen_first_reqs) and all(
req.is_disagg_generation_transmission_in_progress
for req in self.active_requests
])
for req in non_gen_first_reqs)

if need_check:
at_least_num = 1 if need_check_one else 0
Expand Down Expand Up @@ -2819,14 +2872,16 @@ def _check_disagg_ctx_schedulable_status(self,
"""
if not self.kv_cache_transceiver:
return
ctx_only_requests = [
gen_first_ctx_requests = [
req for req in new_requests
if req.is_context_only_request and req.py_disaggregated_params.
schedule_style == DisaggScheduleStyle.GENERATION_FIRST
]
if ctx_only_requests:
self.kv_cache_transceiver.prepare_context_requests(
ctx_only_requests)
# Always call prepare_context_requests when there are new requests
# or previously-waiting requests, so the tp_allgather consensus
# can promote requests whose peer info has arrived on all ranks.
self.kv_cache_transceiver.prepare_context_requests(
gen_first_ctx_requests)

@nvtx_range("_pad_attention_dp_dummy_request")
def _pad_attention_dp_dummy_request(self):
Expand Down Expand Up @@ -2981,10 +3036,14 @@ def _recv_disagg_gen_cache(self, new_gen_reqs):
if req.state == LlmRequestState.DISAGG_GENERATION_TRANS_IN_PROGRESS:
req.py_kv_transfer_start_time = time.time()

block_transfer = all([
non_gen_first_active = [
req for req in self.active_requests
if req.py_disaggregated_params and req.py_disaggregated_params.
schedule_style != DisaggScheduleStyle.GENERATION_FIRST
]
block_transfer = bool(non_gen_first_active) and all(
req.is_disagg_generation_transmission_in_progress
for req in self.active_requests
])
for req in non_gen_first_active)
self._check_disagg_gen_cache_transfer_status(1 if block_transfer else 0)

return
Expand Down
10 changes: 9 additions & 1 deletion tensorrt_llm/commands/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,6 +1102,12 @@ def serve_encoder(model: str, host: str, port: int, log_level: str,
type=click.Choice(severity_map.keys()),
default='info',
help="The logging level.")
@click.option("-s",
"--schedule_style",
type=click.Choice(["context_first", "generation_first"],
case_sensitive=False),
default=None,
help="The schedule style for the disaggregated server.")
@click.option(
"--metrics-log-interval",
type=int,
Expand All @@ -1116,6 +1122,7 @@ def disaggregated(
request_timeout: int,
log_level: str,
metrics_log_interval: int,
schedule_style: str,
):
"""Running server in disaggregated mode"""

Expand All @@ -1130,7 +1137,8 @@ def disaggregated(
logger.warning("--config_file is deprecated, use --config instead.")

disagg_cfg = parse_disagg_config_file(config_file)

if schedule_style:
disagg_cfg.schedule_style = schedule_style
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
try:
s.bind((disagg_cfg.hostname, disagg_cfg.port))
Expand Down
4 changes: 3 additions & 1 deletion tensorrt_llm/disaggregated_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,10 @@ def get_context_phase_params(self) -> tllme.ContextPhaseParams:
request_id = (
self.disagg_request_id if self.disagg_request_id is not None else self.ctx_request_id
)
# `first_gen_tokens` is now required by bindings and cannot be None.
first_gen_tokens = self.first_gen_tokens if self.first_gen_tokens is not None else []
return tllme.ContextPhaseParams(
self.first_gen_tokens,
first_gen_tokens,
request_id,
self.opaque_state,
self.draft_tokens,
Expand Down
15 changes: 14 additions & 1 deletion tensorrt_llm/executor/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from .request import CancellingRequest, GenerationRequest
from .result import GenerationResult, IterationResult
from .rpc import RPCClient
from .rpc.rpc_common import get_unique_ipc_addr
from .rpc.rpc_common import RPCError, get_unique_ipc_addr
from .utils import (ErrorResponse, WorkerCommIpcAddrs, create_mpi_comm_session,
get_spawn_proxy_process_env, is_llm_response,
print_alive_threads)
Expand Down Expand Up @@ -387,6 +387,19 @@ def get_stats(self, timeout: float) -> List[dict]:
stats = self.rpc_client.fetch_stats_wait_async(timeout=timeout).remote()
return [json.loads(s) if isinstance(s, str) else s for s in stats]

def get_disaggregated_params(self) -> dict:
    """Fetch the disaggregated params dict from the worker runtime over RPC.

    Returns an empty dict when the RPC client is unavailable, when the
    RPC call fails, or when the remote result is not a dict.
    """
    client = self.rpc_client
    if client is None:
        logger.warning(
            "RPC client not initialized, cannot get disaggregated params")
        return {}
    try:
        result = client.get_disaggregated_params().remote()
    except RPCError as e:
        logger.warning(f"Error fetching disaggregated params via RPC: {e}")
        return {}
    if isinstance(result, dict):
        return result
    return {}

def aget_stats(self, timeout: float) -> IterationResult:
"""Get iteration statistics from the runtime via RPC (async).

Expand Down
9 changes: 8 additions & 1 deletion tensorrt_llm/llmapi/disagg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@ def extract_disagg_cfg(hostname: str = 'localhost',
conditional_disagg_config: Optional[dict] = None,
otlp_config: Optional[dict] = None,
disagg_cluster: Optional[dict] = None,
node_id: Optional[int] = None,
schedule_style: Literal[
'context_first',
'generation_first'] = 'context_first',
**kwargs: Any) -> DisaggServerConfig:
context_servers = context_servers or {}
generation_servers = generation_servers or {}
Expand Down Expand Up @@ -174,7 +178,10 @@ def extract_disagg_cfg(hostname: str = 'localhost',
conditional_disagg_config, otlp_config,
max_retries, perf_metrics_max_requests,
disagg_cluster_config)

if node_id is not None:
config.node_id = node_id
if schedule_style:
config.schedule_style = schedule_style
return config


Expand Down
2 changes: 1 addition & 1 deletion tensorrt_llm/llmapi/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def __init__(self,
self._executor_cls = kwargs.pop("executor_cls", GenerationExecutor)
self._orchestrator_type = kwargs.get("orchestrator_type", None)
self._llm_id = None
self._disaggregated_params = {}
self._disaggregated_params: Optional[dict] = None

log_level = logger.level
logger.set_level("info") # force display the backend
Expand Down
9 changes: 7 additions & 2 deletions tensorrt_llm/serve/openai_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2025, NVIDIA CORPORATION.
# Copyright (c) 2026, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -288,7 +288,12 @@ async def shutdown(self) -> None:
await self._session.close()

async def check_ready(self) -> Tuple[List[str], List[str]]:
    """Probe all routed servers and return ``(ready, unready)`` URL lists.

    Servers that report ready are additionally registered with the
    router via ``prepare_servers`` before the result is returned, so the
    router only dispatches to servers known to be up.
    """
    ready_servers, unready_servers = await OpenAIHttpClient.check_ready_for_servers(
        self._session, self._router.servers
    )
    if ready_servers:
        await self._router.prepare_servers(ready_servers)
    return ready_servers, unready_servers

@staticmethod
async def check_ready_for_servers(
Expand Down
Loading
Loading