
Commit 91bc17a

test and fix for agent

Signed-off-by: Shixiaowei02 <[email protected]>
1 parent 0661079

File tree

3 files changed: +179 −5 lines changed

tensorrt_llm/_torch/disaggregation/base/agent.py

Lines changed: 4 additions & 2 deletions
@@ -81,7 +81,7 @@ def check_remote_descs(self, name: str, memory_descs: List[int]) -> bool: ...
 @dataclass
 class RegMemoryDescs:
     type: str
-    descs: List[Tuple[int, int, int, str]]
+    descs: List[Tuple[int, int, int, str]]  # (ptr, size, device_id, name)


 def _force_py_nixl_kv_transfer() -> bool:
@@ -123,4 +123,6 @@ def _try_load_cpp_binding():
         BaseTransferAgent = _cpp_binding.BaseTransferAgent
         logger.info("Using Pybind transfer agent binding for Transfer Agent implementation.")
     else:
-        logger.info("Failed to import Pybind transfer agent binding, using pure Python implementation.")
+        logger.warning(
+            "Failed to import Pybind transfer agent binding, using pure Python implementation."
+        )

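The new inline comment pins down the tuple layout as (ptr, size, device_id, name). For illustration, a minimal sketch of building a RegMemoryDescs for one CUDA buffer; the buffer and the name "kv_block_0" are made up for this example, not taken from the commit:

import torch

from tensorrt_llm._torch.disaggregation.base.agent import RegMemoryDescs

# Illustrative only: describe one CUDA buffer as (ptr, size, device_id, name).
buf = torch.zeros(4096, dtype=torch.uint8, device="cuda:0")
descs = RegMemoryDescs(
    type="VRAM",
    descs=[(buf.data_ptr(), buf.numel(), 0, "kv_block_0")],
)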
tensorrt_llm/_torch/disaggregation/nixl/_agent_py.py

Lines changed: 19 additions & 3 deletions
@@ -37,19 +37,33 @@ def wait(self):
 class NixlTransferAgent(BaseTransferAgent):
     """NixlTransferAgent using Python nixl library."""

-    def __init__(self, name: str, use_prog_thread: bool, num_workers: int = 1):
+    def __init__(self, name: str, use_prog_thread: bool = True, num_workers: int = 1):
+        """
+        Initialize NixlTransferAgent.
+        :param name: Name of the agent.
+        :param use_prog_thread: Whether to enable the progress thread, if available.
+        :param num_workers: Number of threads for backends that support multi-threading.
+        """
         self.name = name
+        self.backends = ["UCX"]
         agent_config = nixl_agent_config(
-            enable_prog_thread=use_prog_thread, backends=["UCX"], num_threads=num_workers
+            enable_prog_thread=use_prog_thread, backends=self.backends, num_threads=num_workers
         )
         self.agent = nixl_agent(name, agent_config)

     def register_memory(self, descs: RegMemoryDescs):
+        if isinstance(descs.descs[0], tuple):
+            assert len(descs.descs[0]) == 4, f"Expected 4 elements per desc, got {descs.descs[0]}"
         reg_descs = self.agent.get_reg_descs(descs.descs, descs.type)
+        assert reg_descs is not None, "Failed to get reg_descs"
         self.agent.register_memory(reg_descs)

     def deregister_memory(self, descs: RegMemoryDescs):
-        self.agent.deregister_memory(descs.descs, descs.type)
+        if isinstance(descs.descs[0], tuple):
+            assert len(descs.descs[0]) == 4, f"Expected 4 elements per desc, got {descs.descs[0]}"
+        reg_descs = self.agent.get_reg_descs(descs.descs, descs.type)
+        assert reg_descs is not None, "Failed to get reg_descs"
+        self.agent.deregister_memory(reg_descs)

     def load_remote_agent(self, name: str, agent_desc: bytes):
         self.agent.add_remote_agent(agent_desc)
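Since use_prog_thread now defaults to True, an agent can be constructed from just a name, which is exactly what the fixtures in the new test rely on. A minimal sketch, assuming the nixl Python package is installed; the agent names are illustrative:

from tensorrt_llm._torch.disaggregation.nixl.agent import NixlTransferAgent

# Defaults: progress thread enabled, one worker thread, UCX backend.
agent = NixlTransferAgent(name="demo_agent")

# Explicitly disable the progress thread and use four worker threads instead.
agent_mt = NixlTransferAgent(name="demo_agent_mt", use_prog_thread=False, num_workers=4)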
@@ -70,6 +84,8 @@ def notify_sync_message(self, name: str, sync_message: str):
     def submit_transfer_requests(self, request: TransferRequest) -> TransferStatus:
         src_xfer_descs = self.agent.get_xfer_descs(request.src_descs.descs, request.src_descs.type)
         dst_xfer_descs = self.agent.get_xfer_descs(request.dst_descs.descs, request.dst_descs.type)
+        assert src_xfer_descs is not None, "Failed to get src_xfer_descs"
+        assert dst_xfer_descs is not None, "Failed to get dst_xfer_descs"
         sync_message = "" if request.sync_message is None else request.sync_message
         handle = self.agent.initialize_xfer(
             request.op,
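The added asserts make malformed descriptors fail at the call site instead of deep inside nixl. A hedged illustration of what register_memory now rejects; "agent" is the instance from the sketch above, and the pointer value is fake and never dereferenced because the assert fires first:

from tensorrt_llm._torch.disaggregation.base.agent import RegMemoryDescs

# Deliberately malformed: only (ptr, size, device_id), the name field is missing.
bad = RegMemoryDescs(type="DRAM", descs=[(0x7F0000000000, 256, 0)])

try:
    agent.register_memory(bad)
except AssertionError as err:
    print(err)  # Expected 4 elements per desc, got (139637976727552, 256, 0)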
Lines changed: 156 additions & 0 deletions

@@ -0,0 +1,156 @@
+from dataclasses import dataclass, field
+
+import pytest
+import torch
+
+from tensorrt_llm import logger
+from tensorrt_llm._torch.disaggregation.base.agent import (
+    MemoryDescs,
+    MemoryType,
+    RegMemoryDescs,
+    TransferOp,
+    TransferRequest,
+)
+from tensorrt_llm._torch.disaggregation.nixl.agent import NixlTransferAgent
+
+
+def _convert_to_memory_descs(reg_descs: RegMemoryDescs) -> MemoryDescs:
+    tuples = [(ptr, size, device_id) for (ptr, size, device_id, _) in reg_descs.descs]
+
+    def _convert_memory_type(py_type: str) -> MemoryType:
+        """Convert a Python memory type string to the C++ MemoryType."""
+        type_map = {
+            "DRAM": MemoryType.DRAM,
+            "VRAM": MemoryType.VRAM,
+            "GPU": MemoryType.VRAM,
+            "BLK": MemoryType.BLK,
+            "OBJ": MemoryType.OBJ,
+            "FILE": MemoryType.FILE,
+        }
+        return type_map.get(py_type.upper(), MemoryType.VRAM)
+
+    return MemoryDescs(_convert_memory_type(reg_descs.type), tuples)
+
+
+@dataclass
+class MemoryManager:
+    allocated_memory: list[torch.Tensor] = field(default_factory=list)
+
+    def allocate_memory(
+        self, size: int, name: str, memory_type: str = "VRAM", device_id: int = 0
+    ) -> RegMemoryDescs:
+        device = torch.device(f"cuda:{device_id}" if memory_type == "VRAM" else "cpu")
+
+        # Allocate a memory block backed by a torch.Tensor and keep it alive
+        block = torch.zeros(size, dtype=torch.uint8, device=device)
+        self.allocated_memory.append(block)
+
+        # Describe the block as a (ptr, size, device_id, name) tuple
+        memory_descs = RegMemoryDescs(
+            type=memory_type, descs=[(block.data_ptr(), block.numel(), device_id, name)]
+        )
+        return memory_descs
+
+    def clear_memory(self):
+        """Clear all tracked memory blocks."""
+        self.allocated_memory.clear()
+
+
+@pytest.fixture
+def memory_manager():
+    return MemoryManager()
+
+
+@pytest.fixture(params=[256, 512])
+def memory_size(request):
+    return request.param
+
+
+@pytest.fixture(params=["DRAM", "VRAM"])
+def memory_type(request):
+    return request.param
+
+
+@pytest.fixture
+def alloc(memory_manager, memory_size, memory_type):
+    """Allocate source and destination memory based on the memory_size and memory_type parameters."""
+    assert memory_size > 0, "Memory size must be a positive integer."
+    src_descs = memory_manager.allocate_memory(
+        size=memory_size, name="src_mem", memory_type=memory_type
+    )
+    dst_descs = memory_manager.allocate_memory(
+        size=memory_size, name="dst_mem", memory_type=memory_type
+    )
+    return src_descs, dst_descs
+
+
+@pytest.fixture
+def transfer_agent_src():
+    return NixlTransferAgent(name="src_agent")
+
+
+@pytest.fixture
+def transfer_agent_dst():
+    return NixlTransferAgent(name="dst_agent")
+
+
+def test_transfer_between_agents(
+    transfer_agent_src,
+    transfer_agent_dst,
+    memory_manager,
+    alloc,
+    memory_size,
+    memory_type,
+):
+    """End-to-end test of data transfer between two agents with parameterized memory sizes and types."""
+    # Log the parameters being tested
+    logger.info(f"Testing with memory_size={memory_size}, memory_type={memory_type}")
+
+    # Unpack source and destination memory descriptions
+    memory_descs_src, memory_descs_dst = alloc
+
+    # Fill source memory with a repeating 0-9 pattern for validation
+    src_data = memory_manager.allocated_memory[0]
+    assert memory_size > 0, "Memory size must be positive."
+    tensor = (torch.arange(memory_size, dtype=torch.int64) % 10).to(torch.uint8)
+    src_data.copy_(tensor)
+
+    # Register memory with source and destination agents
+    transfer_agent_src.register_memory(memory_descs_src)
+    transfer_agent_dst.register_memory(memory_descs_dst)
+
+    src_agent_desc = transfer_agent_src.get_local_agent_desc()
+    transfer_agent_dst.load_remote_agent("src_agent", src_agent_desc)
+
+    dst_agent_desc = transfer_agent_dst.get_local_agent_desc()
+    transfer_agent_src.load_remote_agent("dst_agent", dst_agent_desc)
+
+    # Create and submit the transfer request
+    transfer_request = TransferRequest(
+        op=TransferOp.WRITE,
+        src_descs=_convert_to_memory_descs(memory_descs_src),
+        dst_descs=_convert_to_memory_descs(memory_descs_dst),
+        remote_name="dst_agent",
+        sync_message=None,
+    )
+    transfer_status = transfer_agent_src.submit_transfer_requests(transfer_request)
+    transfer_status.wait()
+
+    # Validate transfer completion
+    assert transfer_status.is_completed(), "Transfer did not complete successfully."
+
+    # Validate that the destination data matches the source data
+    dst_data = memory_manager.allocated_memory[1]
+    assert torch.equal(dst_data, src_data), "Destination data does not match source data."
+
+    # Clean up by deregistering memory and clearing allocations
+    transfer_agent_src.deregister_memory(memory_descs_src)
+    transfer_agent_dst.deregister_memory(memory_descs_dst)
+    memory_manager.clear_memory()
+
+    transfer_agent_src.invalidate_remote_agent("dst_agent")
+    transfer_agent_dst.invalidate_remote_agent("src_agent")
+
+
+if __name__ == "__main__":
+    pytest.main()
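The test needs at least one CUDA device for the VRAM cases and a nixl installation with the UCX backend. A hedged way to run only the DRAM parameterizations from Python, assuming the file is saved as test_nixl_agent.py (a hypothetical name; the actual path is not shown in this view):

import pytest

# Hypothetical file name; -k filters on the parameter id in the test name.
pytest.main(["-v", "-k", "DRAM", "test_nixl_agent.py"])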
