Skip to content

Commit 5cca4a0

Browse files
committed
refactor the agent
Signed-off-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
1 parent cd618b9 commit 5cca4a0

File tree

5 files changed

+271
-359
lines changed

5 files changed

+271
-359
lines changed

tensorrt_llm/_torch/disaggregation/base/agent.py

Lines changed: 72 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -2,126 +2,77 @@
22
from dataclasses import dataclass
33
from typing import Enum, List, Tuple, Union
44

5-
from tensorrt_llm.logger import logger
6-
7-
# Try to import C++ bindings for zero-copy performance
8-
try:
9-
from tensorrt_llm.tensorrt_llm_transfer_agent_binding import (
10-
BaseTransferAgent,
11-
MemoryDesc,
12-
MemoryDescs,
13-
MemoryType,
14-
TransferOp,
15-
TransferRequest,
16-
TransferStatus,
17-
)
18-
19-
_CPP_BINDING_AVAILABLE = True
20-
except ImportError:
21-
_CPP_BINDING_AVAILABLE = False
22-
logger.warning(
23-
"C++ transfer agent bindings not available. "
24-
"Falling back to Python implementations which may have lower performance."
25-
)
26-
27-
28-
def is_cpp_binding_available() -> bool:
29-
"""Check if C++ transfer agent bindings are available."""
30-
return _CPP_BINDING_AVAILABLE
31-
32-
33-
# Fallback Python implementations when C++ bindings not available
34-
if not _CPP_BINDING_AVAILABLE:
35-
36-
class TransferOp(Enum):
37-
READ = "READ"
38-
WRITE = "WRITE"
39-
40-
class MemoryType(Enum):
41-
DRAM = "DRAM"
42-
VRAM = "VRAM"
43-
BLK = "BLK"
44-
OBJ = "OBJ"
45-
FILE = "FILE"
46-
47-
@dataclass
48-
class MemoryDesc:
49-
ptr: int
50-
size: int
51-
device_id: int
52-
53-
@dataclass
54-
class MemoryDescs:
55-
type: str
56-
descs: List[Union[Tuple[int, int, int], MemoryDesc]]
57-
58-
@dataclass
59-
class TransferRequest:
60-
op: TransferOp
61-
src_descs: MemoryDescs
62-
dst_descs: MemoryDescs
63-
remote_name: str
64-
sync_message: str
65-
66-
class TransferStatus(ABC):
67-
@abstractmethod
68-
def is_completed(self) -> bool: ...
69-
70-
@abstractmethod
71-
def wait(self, timeout: float | None = None) -> None: ...
72-
73-
class BaseTransferAgent(ABC):
74-
@abstractmethod
75-
def register_memory(self, descs: MemoryDescs) -> None:
76-
"""Register a set of memory descriptors on the agent."""
77-
...
78-
79-
@abstractmethod
80-
def deregister_memory(self, descs: MemoryDescs) -> None:
81-
"""De-register a set of memory descriptors on the agent."""
82-
...
83-
84-
@abstractmethod
85-
def load_remote_agent(self, name: str, agent_desc: str) -> None:
86-
"""
87-
Load information about a remote agent specified by name.
88-
89-
Args:
90-
name (str): The remote agent's identifier.
91-
agent_desc (str): A serialized description of the agent.
92-
"""
93-
...
94-
95-
@abstractmethod
96-
def get_local_agent_desc(self) -> str:
97-
"""Return the serialized description of this agent."""
98-
...
99-
100-
@abstractmethod
101-
def invalidate_remote_agent(self, name: str) -> None:
102-
"""Invalidate any cached information about the specified remote agent."""
103-
...
104-
105-
@abstractmethod
106-
def submit_transfer_requests(self, request: TransferRequest) -> TransferStatus:
107-
"""Submit transfer tasks to the agent based on a request."""
108-
...
109-
110-
@abstractmethod
111-
def notify_sync_message(self, name: str, sync_message: str) -> None:
112-
"""Send a synchronization message to the specified remote agent."""
113-
...
114-
115-
@abstractmethod
116-
def check_remote_descs(self, name: str, memory_descs: List[int]) -> bool:
117-
"""
118-
Verify the remote agent's memory descriptors.
119-
"""
120-
...
121-
122-
123-
# RegMemoryDescs is Python-only (used for registration with name field)
5+
6+
# Common Enumerations
7+
class TransferOp(Enum):
8+
READ = "READ"
9+
WRITE = "WRITE"
10+
11+
12+
class MemoryType(Enum):
13+
DRAM = "DRAM"
14+
VRAM = "VRAM"
15+
BLK = "BLK"
16+
OBJ = "OBJ"
17+
FILE = "FILE"
18+
19+
20+
# Common Data Structures
12421
@dataclass
125-
class RegMemoryDescs:
22+
class MemoryDesc:
23+
ptr: int
24+
size: int
25+
device_id: int
26+
27+
28+
@dataclass
29+
class MemoryDescs:
12630
type: str
127-
descs: List[Tuple[int, int, int, str]]
31+
descs: List[Union[Tuple[int, int, int], MemoryDesc]]
32+
33+
34+
@dataclass
35+
class TransferRequest:
36+
op: TransferOp
37+
src_descs: MemoryDescs
38+
dst_descs: MemoryDescs
39+
remote_name: str
40+
sync_message: str
41+
42+
43+
class BaseTransferStatus(ABC):
44+
"""Abstract base class for transfer status."""
45+
46+
@abstractmethod
47+
def is_completed(self) -> bool: ...
48+
49+
@abstractmethod
50+
def wait(self, timeout: float | None = None) -> None: ...
51+
52+
53+
class BaseTransferAgent(ABC):
54+
"""Abstract base class for transfer agents."""
55+
56+
@abstractmethod
57+
def register_memory(self, descs: MemoryDescs) -> None: ...
58+
59+
@abstractmethod
60+
def deregister_memory(self, descs: MemoryDescs) -> None: ...
61+
62+
@abstractmethod
63+
def load_remote_agent(self, name: str, agent_desc: str) -> None: ...
64+
65+
@abstractmethod
66+
def get_local_agent_desc(self) -> str: ...
67+
68+
@abstractmethod
69+
def invalidate_remote_agent(self, name: str) -> None: ...
70+
71+
@abstractmethod
72+
def submit_transfer_requests(self, request: TransferRequest) -> BaseTransferStatus: ...
73+
74+
@abstractmethod
75+
def notify_sync_message(self, name: str, sync_message: str) -> None: ...
76+
77+
@abstractmethod
78+
def check_remote_descs(self, name: str, memory_descs: List[int]) -> bool: ...

tensorrt_llm/_torch/disaggregation/native/messenger.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def listener():
150150
poller.register(self._control_socket, zmq.POLLIN)
151151

152152
while not self._stop_event.is_set():
153-
events = dict(poller.poll())
153+
events = dict(poller.poll(timeout=100))
154154
try:
155155
if self._control_socket in events:
156156
self._stop_event.set()
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
from tensorrt_llm._utils import nvtx_range
2+
from tensorrt_llm.tensorrt_llm_transfer_agent_binding import (
3+
AgentDesc,
4+
BaseAgentConfig,
5+
MemoryDescs,
6+
MemoryType,
7+
TransferState,
8+
)
9+
from tensorrt_llm.tensorrt_llm_transfer_agent_binding import (
10+
NixlTransferAgent as CppNixlTransferAgent,
11+
)
12+
from tensorrt_llm.tensorrt_llm_transfer_agent_binding import (
13+
NixlTransferStatus as CppNixlTransferStatus,
14+
)
15+
16+
from ..base.agent import BaseTransferAgent, BaseTransferStatus, RegMemoryDescs, TransferRequest
17+
18+
19+
class NixlTransferStatus(BaseTransferStatus):
20+
def __init__(self, cpp_status: CppNixlTransferStatus):
21+
self._cpp_status = cpp_status
22+
23+
def is_completed(self) -> bool:
24+
"""Check if transfer is completed (releases GIL)."""
25+
return self._cpp_status.is_completed()
26+
27+
@nvtx_range("NixlTransferStatus.wait")
28+
def wait(self, timeout: float = None) -> bool:
29+
"""Wait for transfer to complete (releases GIL)."""
30+
return self._cpp_status.wait() == TransferState.SUCCESS
31+
32+
33+
class NixlTransferAgent(BaseTransferAgent):
34+
"""NixlTransferAgent using C++ bindings with GIL release support.
35+
36+
This implementation uses the standalone nixl_bindings C++ module which releases
37+
the GIL during blocking operations like wait().
38+
39+
The nixl_bindings module is independent from the main trtllm bindings,
40+
so trtllm can function normally even without NIXL.
41+
"""
42+
43+
def __init__(self, name: str, use_prog_thread: bool = True, num_workers: int = 1):
44+
config = BaseAgentConfig(
45+
name=name,
46+
use_prog_thread=use_prog_thread,
47+
multi_thread=False,
48+
use_listen_thread=False,
49+
num_workers=num_workers,
50+
)
51+
self._cpp_agent = CppNixlTransferAgent(config)
52+
self.name = name
53+
54+
def register_memory(self, descs: RegMemoryDescs):
55+
cpp_descs = self._convert_reg_memory_descs(descs)
56+
self._cpp_agent.register_memory(cpp_descs)
57+
58+
def deregister_memory(self, descs: RegMemoryDescs):
59+
cpp_descs = self._convert_reg_memory_descs(descs)
60+
self._cpp_agent.deregister_memory(cpp_descs)
61+
62+
def load_remote_agent(self, name: str, agent_desc: bytes):
63+
desc_str = agent_desc if isinstance(agent_desc, bytes) else agent_desc.encode()
64+
cpp_desc = AgentDesc(desc_str)
65+
self._cpp_agent.load_remote_agent(name, cpp_desc)
66+
67+
def load_remote_agent_by_connection(self, name: str, connection_info: str):
68+
self._cpp_agent.load_remote_agent_by_connection(name, connection_info)
69+
70+
def get_local_agent_desc(self) -> bytes:
71+
agent_desc = self._cpp_agent.get_local_agent_desc()
72+
return agent_desc.backend_agent_desc
73+
74+
def get_local_connection_info(self) -> str:
75+
return self._cpp_agent.get_local_connection_info()
76+
77+
def invalidate_remote_agent(self, name: str):
78+
self._cpp_agent.invalidate_remote_agent(name)
79+
80+
def check_remote_descs(self, name: str, memory_descs: MemoryDescs) -> bool:
81+
return self._cpp_agent.check_remote_descs(name, memory_descs)
82+
83+
def notify_sync_message(self, name: str, sync_message: str):
84+
self._cpp_agent.notify_sync_message(name, sync_message)
85+
86+
def get_notified_sync_messages(self):
87+
return self._cpp_agent.get_notified_sync_messages()
88+
89+
@nvtx_range("BindingsNixlTransferAgent.submit_transfer_requests")
90+
def submit_transfer_requests(self, request: TransferRequest) -> BaseTransferStatus:
91+
cpp_status = self._cpp_agent.submit_transfer_requests(request)
92+
return NixlTransferStatus(cpp_status)
93+
94+
def _convert_reg_memory_descs(self, descs: RegMemoryDescs) -> "MemoryDescs":
95+
mem_type = self._convert_memory_type(descs.type)
96+
tuples = [(d[0], d[1], d[2]) for d in descs.descs] # Extract (ptr, size, device_id)
97+
return MemoryDescs(mem_type, tuples)
98+
99+
def _convert_memory_type(self, py_type: str) -> "MemoryType":
100+
type_map = {
101+
"DRAM": MemoryType.DRAM,
102+
"VRAM": MemoryType.VRAM,
103+
"GPU": MemoryType.VRAM,
104+
"BLK": MemoryType.BLK,
105+
"OBJ": MemoryType.OBJ,
106+
"FILE": MemoryType.FILE,
107+
}
108+
return type_map.get(py_type.upper(), MemoryType.VRAM)
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import time
2+
3+
from nixl import nixl_agent, nixl_agent_config, nixl_xfer_handle
4+
5+
from tensorrt_llm._utils import nvtx_range
6+
7+
from ..base.agent import BaseTransferAgent, BaseTransferStatus, RegMemoryDescs, TransferRequest
8+
9+
10+
class NixlTransferStatus(BaseTransferStatus):
11+
"""TransferStatus using the Python NIXL library."""
12+
13+
def __init__(self, agent: nixl_agent, handle: nixl_xfer_handle):
14+
self.agent = agent
15+
self.handle = handle
16+
17+
def is_completed(self) -> bool:
18+
status = self.agent.check_xfer_state(self.handle)
19+
return status == "DONE"
20+
21+
def wait(self) -> bool:
22+
status = "PROC"
23+
sleep_time = 0.0001 # 0.1ms
24+
max_sleep_time = 0.01 # 10ms
25+
while status == "PROC":
26+
status = self.agent.check_xfer_state(self.handle)
27+
if status == "ERROR":
28+
return False # Transfer failed
29+
time.sleep(sleep_time) # Sleep to release GIL
30+
sleep_time = min(sleep_time * 2, max_sleep_time)
31+
return status == "DONE"
32+
33+
34+
class NixlTransferAgent(BaseTransferAgent):
35+
"""Python-based TransferAgent using the NIXL library."""
36+
37+
def __init__(self, name: str, use_prog_thread: bool, num_workers: int = 1):
38+
self.name = name
39+
agent_config = nixl_agent_config(
40+
enable_prog_thread=use_prog_thread,
41+
backends=["UCX"],
42+
num_threads=num_workers,
43+
)
44+
self.agent = nixl_agent(name, agent_config)
45+
46+
def register_memory(self, descs: RegMemoryDescs):
47+
reg_descs = self.agent.get_reg_descs(descs.descs, descs.type)
48+
self.agent.register_memory(reg_descs)
49+
50+
def deregister_memory(self, descs: RegMemoryDescs):
51+
self.agent.deregister_memory(descs.descs, descs.type)
52+
53+
def load_remote_agent(self, name: str, agent_desc: bytes):
54+
self.agent.add_remote_agent(agent_desc)
55+
56+
def get_local_agent_desc(self) -> bytes:
57+
return self.agent.get_agent_metadata()
58+
59+
def invalidate_remote_agent(self, name: str):
60+
self.agent.remove_remote_agent(name)
61+
62+
def check_remote_descs(self, name: str, memory_descs: list[int]) -> bool:
63+
raise NotImplementedError("check_remote_descs is not implemented.")
64+
65+
def notify_sync_message(self, name: str, sync_message: str):
66+
raise NotImplementedError("notify_sync_message is not implemented.")
67+
68+
@nvtx_range("NixlTransferAgent.submit_transfer_requests")
69+
def submit_transfer_requests(self, request: TransferRequest) -> BaseTransferStatus:
70+
src_xfer_descs = self.agent.get_xfer_descs(request.src_descs.descs, request.src_descs.type)
71+
dst_xfer_descs = self.agent.get_xfer_descs(request.dst_descs.descs, request.dst_descs.type)
72+
handle = self.agent.initialize_xfer(
73+
request.op,
74+
src_xfer_descs,
75+
dst_xfer_descs,
76+
request.remote_name,
77+
request.sync_message,
78+
)
79+
status = self.agent.transfer(handle)
80+
assert status != "ERROR", "Transfer failed during initialization."
81+
return NixlTransferStatus(self.agent, handle)

0 commit comments

Comments
 (0)