Skip to content

Commit cd618b9

Browse files
committed
update tests
Signed-off-by: Shixiaowei02 <[email protected]>
1 parent ad69bd0 commit cd618b9

File tree

4 files changed

+320
-79
lines changed

4 files changed

+320
-79
lines changed

tensorrt_llm/_torch/disaggregation/base/agent.py

Lines changed: 50 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,28 @@
11
from abc import ABC, abstractmethod
22
from dataclasses import dataclass
3-
from typing import List, Tuple, Union
3+
from enum import Enum
from typing import List, Tuple, Union
4+
5+
from tensorrt_llm.logger import logger
46

57
# Try to import C++ bindings for zero-copy performance
6-
_CPP_BINDING_AVAILABLE = False
78
try:
8-
import tensorrt_llm.tensorrt_llm_transfer_agent_binding as _cpp_binding
9+
from tensorrt_llm.tensorrt_llm_transfer_agent_binding import (
10+
BaseTransferAgent,
11+
MemoryDesc,
12+
MemoryDescs,
13+
MemoryType,
14+
TransferOp,
15+
TransferRequest,
16+
TransferStatus,
17+
)
918

1019
_CPP_BINDING_AVAILABLE = True
11-
# Use C++ types directly when available
12-
MemoryType = _cpp_binding.MemoryType
13-
TransferOp = _cpp_binding.TransferOp
14-
MemoryDesc = _cpp_binding.MemoryDesc
15-
MemoryDescs = _cpp_binding.MemoryDescs
16-
TransferRequest = _cpp_binding.TransferRequest
17-
TransferStatus = _cpp_binding.TransferStatus
18-
BaseTransferAgent = _cpp_binding.BaseTransferAgent
1920
except ImportError:
2021
_CPP_BINDING_AVAILABLE = False
22+
logger.warning(
23+
"C++ transfer agent bindings not available. "
24+
"Falling back to Python implementations which may have lower performance."
25+
)
2126

2227

2328
def is_cpp_binding_available() -> bool:
@@ -28,11 +33,11 @@ def is_cpp_binding_available() -> bool:
2833
# Fallback Python implementations used when the C++ bindings are not available
2934
if not _CPP_BINDING_AVAILABLE:
3035

31-
class TransferOp:
36+
class TransferOp(Enum):
3237
READ = "READ"
3338
WRITE = "WRITE"
3439

35-
class MemoryType:
40+
class MemoryType(Enum):
3641
DRAM = "DRAM"
3742
VRAM = "VRAM"
3843
BLK = "BLK"
@@ -67,28 +72,52 @@ def wait(self, timeout: float | None = None) -> None: ...
6772

6873
class BaseTransferAgent(ABC):
6974
@abstractmethod
70-
def register_memory(self, descs: MemoryDescs) -> None: ...
75+
def register_memory(self, descs: MemoryDescs) -> None:
76+
"""Register a set of memory descriptors on the agent."""
77+
...
7178

7279
@abstractmethod
73-
def deregister_memory(self, descs: MemoryDescs) -> None: ...
80+
def deregister_memory(self, descs: MemoryDescs) -> None:
81+
"""De-register a set of memory descriptors on the agent."""
82+
...
7483

7584
@abstractmethod
76-
def load_remote_agent(self, name: str, agent_desc: str) -> None: ...
85+
def load_remote_agent(self, name: str, agent_desc: str) -> None:
86+
"""
87+
Load information about a remote agent specified by name.
88+
89+
Args:
90+
name (str): The remote agent's identifier.
91+
agent_desc (str): A serialized description of the agent.
92+
"""
93+
...
7794

7895
@abstractmethod
79-
def get_local_agent_desc(self) -> str: ...
96+
def get_local_agent_desc(self) -> str:
97+
"""Return the serialized description of this agent."""
98+
...
8099

81100
@abstractmethod
82-
def invalidate_remote_agent(self, name: str) -> None: ...
101+
def invalidate_remote_agent(self, name: str) -> None:
102+
"""Invalidate any cached information about the specified remote agent."""
103+
...
83104

84105
@abstractmethod
85-
def submit_transfer_requests(self, request: TransferRequest) -> TransferStatus: ...
106+
def submit_transfer_requests(self, request: TransferRequest) -> TransferStatus:
107+
"""Submit transfer tasks to the agent based on a request."""
108+
...
86109

87110
@abstractmethod
88-
def notify_sync_message(self, name: str, sync_message: str) -> None: ...
111+
def notify_sync_message(self, name: str, sync_message: str) -> None:
112+
"""Send a synchronization message to the specified remote agent."""
113+
...
89114

90115
@abstractmethod
91-
def check_remote_descs(self, name: str, memory_descs: List[int]) -> bool: ...
116+
def check_remote_descs(self, name: str, memory_descs: List[int]) -> bool:
117+
"""
118+
Verify the remote agent's memory descriptors.
119+
"""
120+
...
92121

93122

94123
# RegMemoryDescs is Python-only (used for registration with name field)

tensorrt_llm/_torch/disaggregation/base/kv_transfer.py

Lines changed: 65 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ class KVSlice:
2121

2222

2323
class SessionStatus(Enum):
24+
"""Status of a transfer session."""
25+
2426
INIT = "INIT"
2527
READY = "READY"
2628
TRANSFERRING = "TRANSFERRING"
@@ -36,63 +38,107 @@ class SessionStatus(Enum):
3638

3739
@dataclass
3840
class SessionState:
41+
"""State of a transfer session."""
42+
3943
status: SessionStatus
4044
finished_tasks: List[TaskIdType]
4145

4246

4347
@dataclass
4448
class SessionArgsBase:
45-
request_id: int
49+
"""Base arguments for transfer sessions."""
50+
4651
params: DisaggregatedParams
4752

4853

49-
class SenderBase(ABC): ...
54+
class SenderBase(ABC):
55+
"""Base class for sending KV cache data."""
56+
57+
...
5058

5159

52-
class ReceiverBase(ABC): ...
60+
class ReceiverBase(ABC):
61+
"""Base class for receiving KV cache data."""
62+
63+
...
5364

5465

5566
class TxSessionBase(ABC):
5667
def __init__(self, sender: SenderBase, args: SessionArgsBase):
68+
"""
69+
Initializes the transmission session.
70+
:param sender: The sender instance responsible for sending data.
71+
:param args: The session arguments.
72+
"""
5773
self._base_args = args
5874

5975
@property
6076
@abstractmethod
61-
def state(self) -> SessionState: ...
77+
def state(self) -> SessionState:
78+
"""
79+
Returns the current state of the session.
80+
"""
81+
...
6282

6383
@abstractmethod
64-
def poll_task(self, id: TaskIdType) -> SessionStatus: ...
84+
def poll_task(self, id: TaskIdType) -> SessionStatus:
85+
"""
86+
Polls the status of a specific task by its ID.
87+
:param id: The task ID to poll.
88+
"""
89+
...
6590

6691
@abstractmethod
67-
def send(self, slice: KVSlice) -> TaskIdType: ...
68-
69-
"""
70-
Async send slice to the peer. return the task id. Task state can be polled by poll_task().
71-
"""
92+
def send(self, slice: KVSlice) -> TaskIdType:
93+
"""
94+
Sends a slice of KV cache data and returns the task ID.
95+
:param slice: The KV slice to send.
96+
"""
97+
...
7298

7399
@property
74100
@abstractmethod
75-
def exception(self) -> Optional[Exception]: ...
101+
def exception(self) -> Optional[Exception]:
102+
"""
103+
Returns any exception that occurred during the session.
104+
"""
105+
...
76106

77107

78108
class RxSessionBase(ABC):
79109
def __init__(self, receiver: ReceiverBase, args: SessionArgsBase):
110+
"""
111+
Initializes the reception session.
112+
:param receiver: The receiver instance responsible for receiving data.
113+
"""
80114
self._base_args = args
81115

82116
@property
83117
@abstractmethod
84-
def state(self) -> SessionState: ...
118+
def state(self) -> SessionState:
119+
"""
120+
Returns the current state of the session.
121+
"""
122+
...
85123

86124
@abstractmethod
87-
def poll_task(self, id: TaskIdType) -> SessionStatus: ...
125+
def poll_task(self, id: TaskIdType) -> SessionStatus:
126+
"""
127+
Polls the status of a specific task by its ID.
128+
:param id: The task ID to poll.
129+
"""
130+
...
88131

89132
@abstractmethod
90-
def receive(self, slice: KVSlice) -> TaskIdType: ...
91-
92-
"""
93-
Async receive slice from the peer. return the task id. Task state can be polled by poll_task().
94-
"""
133+
def receive(self, slice: KVSlice) -> TaskIdType:
134+
"""
135+
Receives a slice of KV cache data and returns the task ID.
136+
:param slice: The KV slice to receive.
137+
"""
138+
...
95139

96140
@property
97141
@abstractmethod
98-
def exception(self) -> Optional[Exception]: ...
142+
def exception(self) -> Optional[Exception]:
143+
"""Returns any exception that occurred during the session."""
144+
...

0 commit comments

Comments
 (0)