From ad69bd03cf75760d978742bee5af1a653b610116 Mon Sep 17 00:00:00 2001 From: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com> Date: Wed, 7 Jan 2026 09:01:11 +0000 Subject: [PATCH 1/3] Python transceiver components Signed-off-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com> --- .../_torch/disaggregation/__init__.py | 0 .../_torch/disaggregation/base/__init__.py | 0 .../_torch/disaggregation/base/agent.py | 98 +++++++ .../_torch/disaggregation/base/kv_transfer.py | 98 +++++++ .../_torch/disaggregation/native/__init__.py | 0 .../_torch/disaggregation/native/messenger.py | 189 +++++++++++++ .../_torch/disaggregation/native/utils.py | 34 +++ .../_torch/disaggregation/nixl/__init__.py | 0 .../_torch/disaggregation/nixl/agent.py | 253 ++++++++++++++++++ 9 files changed, 672 insertions(+) create mode 100644 tensorrt_llm/_torch/disaggregation/__init__.py create mode 100644 tensorrt_llm/_torch/disaggregation/base/__init__.py create mode 100644 tensorrt_llm/_torch/disaggregation/base/agent.py create mode 100644 tensorrt_llm/_torch/disaggregation/base/kv_transfer.py create mode 100644 tensorrt_llm/_torch/disaggregation/native/__init__.py create mode 100644 tensorrt_llm/_torch/disaggregation/native/messenger.py create mode 100644 tensorrt_llm/_torch/disaggregation/native/utils.py create mode 100644 tensorrt_llm/_torch/disaggregation/nixl/__init__.py create mode 100644 tensorrt_llm/_torch/disaggregation/nixl/agent.py diff --git a/tensorrt_llm/_torch/disaggregation/__init__.py b/tensorrt_llm/_torch/disaggregation/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tensorrt_llm/_torch/disaggregation/base/__init__.py b/tensorrt_llm/_torch/disaggregation/base/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tensorrt_llm/_torch/disaggregation/base/agent.py b/tensorrt_llm/_torch/disaggregation/base/agent.py new file mode 100644 index 00000000000..b677f24a489 --- /dev/null +++ b/tensorrt_llm/_torch/disaggregation/base/agent.py @@ -0,0 +1,98 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import List, Tuple, Union + +# Try to import C++ bindings for zero-copy performance +_CPP_BINDING_AVAILABLE = False +try: + import tensorrt_llm.tensorrt_llm_transfer_agent_binding as _cpp_binding + + _CPP_BINDING_AVAILABLE = True + # Use C++ types directly when available + MemoryType = _cpp_binding.MemoryType + TransferOp = _cpp_binding.TransferOp + MemoryDesc = _cpp_binding.MemoryDesc + MemoryDescs = _cpp_binding.MemoryDescs + TransferRequest = _cpp_binding.TransferRequest + TransferStatus = _cpp_binding.TransferStatus + BaseTransferAgent = _cpp_binding.BaseTransferAgent +except ImportError: + _CPP_BINDING_AVAILABLE = False + + +def is_cpp_binding_available() -> bool: + """Check if C++ transfer agent bindings are available.""" + return _CPP_BINDING_AVAILABLE + + +# Fallback Python implementations when C++ bindings not available +if not _CPP_BINDING_AVAILABLE: + + class TransferOp: + READ = "READ" + WRITE = "WRITE" + + class MemoryType: + DRAM = "DRAM" + VRAM = "VRAM" + BLK = "BLK" + OBJ = "OBJ" + FILE = "FILE" + + @dataclass + class MemoryDesc: + ptr: int + size: int + device_id: int + + @dataclass + class MemoryDescs: + type: str + descs: List[Union[Tuple[int, int, int], MemoryDesc]] + + @dataclass + class TransferRequest: + op: TransferOp + src_descs: MemoryDescs + dst_descs: MemoryDescs + remote_name: str + sync_message: str + + class TransferStatus(ABC): + @abstractmethod + def is_completed(self) -> 
bool: ... + + @abstractmethod + def wait(self, timeout: float | None = None) -> None: ... + + class BaseTransferAgent(ABC): + @abstractmethod + def register_memory(self, descs: MemoryDescs) -> None: ... + + @abstractmethod + def deregister_memory(self, descs: MemoryDescs) -> None: ... + + @abstractmethod + def load_remote_agent(self, name: str, agent_desc: str) -> None: ... + + @abstractmethod + def get_local_agent_desc(self) -> str: ... + + @abstractmethod + def invalidate_remote_agent(self, name: str) -> None: ... + + @abstractmethod + def submit_transfer_requests(self, request: TransferRequest) -> TransferStatus: ... + + @abstractmethod + def notify_sync_message(self, name: str, sync_message: str) -> None: ... + + @abstractmethod + def check_remote_descs(self, name: str, memory_descs: List[int]) -> bool: ... + + +# RegMemoryDescs is Python-only (used for registration with name field) +@dataclass +class RegMemoryDescs: + type: str + descs: List[Tuple[int, int, int, str]] diff --git a/tensorrt_llm/_torch/disaggregation/base/kv_transfer.py b/tensorrt_llm/_torch/disaggregation/base/kv_transfer.py new file mode 100644 index 00000000000..ea0597eaa8d --- /dev/null +++ b/tensorrt_llm/_torch/disaggregation/base/kv_transfer.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from enum import Enum +from typing import List, Optional + +from tensorrt_llm import DisaggregatedParams + + +@dataclass +class KVSlice: + """Supports transmitting only part of the request cache, e.g, chunks or layers.""" + + start_token_idx: Optional[int] = None + end_token_idx: Optional[int] = None + start_layer: Optional[int] = None + end_layer: Optional[int] = None + blocks: List[int] = field(default_factory=list) + is_last_slice: bool = False + + +class SessionStatus(Enum): + INIT = "INIT" + READY = "READY" + TRANSFERRING = "TRANSFERRING" + TRANSFERRED = "TRANSFERRED" + AUX_TRANSFERRED = "AUX_TRANSFERRED" + COMPLETED = "COMPLETED" + CANCELED = "CANCELED" + ERROR = "ERROR" + + +TaskIdType = int + + +@dataclass +class SessionState: + status: SessionStatus + finished_tasks: List[TaskIdType] + + +@dataclass +class SessionArgsBase: + request_id: int + params: DisaggregatedParams + + +class SenderBase(ABC): ... + + +class ReceiverBase(ABC): ... + + +class TxSessionBase(ABC): + def __init__(self, sender: SenderBase, args: SessionArgsBase): + self._base_args = args + + @property + @abstractmethod + def state(self) -> SessionState: ... + + @abstractmethod + def poll_task(self, id: TaskIdType) -> SessionStatus: ... + + @abstractmethod + def send(self, slice: KVSlice) -> TaskIdType: ... + + """ + Async send slice to the peer. return the task id. Task state can be polled by poll_task(). + """ + + @property + @abstractmethod + def exception(self) -> Optional[Exception]: ... + + +class RxSessionBase(ABC): + def __init__(self, receiver: ReceiverBase, args: SessionArgsBase): + self._base_args = args + + @property + @abstractmethod + def state(self) -> SessionState: ... + + @abstractmethod + def poll_task(self, id: TaskIdType) -> SessionStatus: ... + + @abstractmethod + def receive(self, slice: KVSlice) -> TaskIdType: ... + + """ + Async receive slice from the peer. return the task id. Task state can be polled by poll_task(). + """ + + @property + @abstractmethod + def exception(self) -> Optional[Exception]: ... 
diff --git a/tensorrt_llm/_torch/disaggregation/native/__init__.py b/tensorrt_llm/_torch/disaggregation/native/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tensorrt_llm/_torch/disaggregation/native/messenger.py b/tensorrt_llm/_torch/disaggregation/native/messenger.py new file mode 100644 index 00000000000..d9e0d92d886 --- /dev/null +++ b/tensorrt_llm/_torch/disaggregation/native/messenger.py @@ -0,0 +1,189 @@ +from abc import ABC, abstractmethod +from threading import Event, Lock, Thread +from typing import Callable, Optional + +import zmq + +from tensorrt_llm import logger +from tensorrt_llm._torch.disaggregation.native.utils import get_local_ip + + +class MessengerInterface(ABC): + """ + Abstract base class for messenger implementations. + """ + + @abstractmethod + def start(self): + """ + Start the messenger service. + """ + ... + + @abstractmethod + def send(self, messages: list[bytes], recipient: Optional[bytes] = None): + """ + Send messages to a recipient. + :param messages: List of byte messages to send. + :param recipient: Optional recipient identifier. + """ + ... + + @abstractmethod + def send_encoded(self, *messages, encoding: str = "ascii"): + """ + Send messages after encoding them. + :param messages: Messages to send. + :param encoding: Encoding format. + """ + ... + + @abstractmethod + def receive(self) -> list[bytes]: + """ + Receive messages. + :return: List of byte messages received. + """ + ... + + @abstractmethod + def start_listener(self, on_message: Callable[[list[bytes]], None]): + """ + Start a listener thread to handle incoming messages. + :param on_message: Callback function to process received messages. + """ + ... + + @abstractmethod + def stop(self): + """ + Stop the messenger service. + """ + ... + + @property + @abstractmethod + def endpoint(self) -> str: + """ + Get the endpoint of the messenger. + :return: Endpoint string. + """ + ... 
+ + +def decode_message(message: list[bytes], encoding: str = "ascii", errors: str = "strict") -> tuple: + if not isinstance(message, list) or not all(isinstance(m, bytes) for m in message): + raise ValueError("Input must be a list of bytes") + return tuple(m.decode(encoding, errors=errors) for m in message) + + +class ZMQMessenger(MessengerInterface): + def __init__(self, mode: str, endpoint: Optional[str] = f"tcp://{get_local_ip()}:*"): + self._context = zmq.Context() + self._mode = mode + self._socket = self._initialize_socket(mode) + self._closed = False + self._listener_thread: Optional[Thread] = None + self._stop_event = Event() + self._lock = Lock() + + self._internal_socket = self._context.socket(zmq.PAIR) + self._control_socket = self._context.socket(zmq.PAIR) + inproc_endpoint = "inproc://stop_listener" + self._control_socket.bind(inproc_endpoint) + self._internal_socket.connect(inproc_endpoint) + + if endpoint: + if mode in ["ROUTER", "REP"]: + self._socket.bind(endpoint) + self._endpoint = self._socket.getsockopt_string(zmq.LAST_ENDPOINT) + elif mode in ["DEALER", "REQ"]: + self._socket.connect(endpoint) + self._endpoint = endpoint + + logger.debug(f"Initializing ZMQMessenger, mode={mode}, endpoint={self._endpoint}") + + def _initialize_socket(self, mode): + if mode == "ROUTER": + return self._context.socket(zmq.ROUTER) + elif mode == "DEALER": + return self._context.socket(zmq.DEALER) + elif mode == "REQ": + return self._context.socket(zmq.REQ) + elif mode == "REP": + return self._context.socket(zmq.REP) + else: + raise ValueError(f"Unsupported ZeroMQ socket mode: {mode}") + + def start(self): + pass + + def send(self, messages: list[bytes], recipient: Optional[bytes] = None): + if recipient: + self._socket.send_multipart([recipient] + messages) + else: + self._socket.send_multipart(messages) + + def send_encoded(self, *messages, encoding: str = "ascii"): + encoded_messages = [str(message).encode(encoding) for message in messages] + self.send(encoded_messages) + + def receive(self) -> list[bytes]: + return self._socket.recv_multipart() + + def start_listener(self, on_message: Callable[[list[bytes]], None]): + if self._listener_thread and self._listener_thread.is_alive(): + raise RuntimeError("Listener already running") + + def listener(): + poller = zmq.Poller() + poller.register(self._socket, zmq.POLLIN) + poller.register(self._control_socket, zmq.POLLIN) + + while not self._stop_event.is_set(): + events = dict(poller.poll()) + try: + if self._control_socket in events: + self._stop_event.set() + elif self._socket in events: + messages = self.receive() + persist = on_message(messages) + if persist is False: + self._stop_event.set() + except zmq.ZMQError as e: + logger.error(f"ZMQ Error in listener: {e}") + continue + except Exception as e: + logger.error(f"Error in listener: {e}") + continue + + self._listener_thread = Thread(target=listener, daemon=True) + self._listener_thread.start() + + def stop(self, timeout=5): + with self._lock: + if self._closed: + return + self._closed = True + self._stop_event.set() + self._internal_socket.send(b"STOP") + if self._listener_thread: + self._listener_thread.join(timeout) + self._socket.close() + self._internal_socket.close() + self._control_socket.close() + self._context.term() + + @property + def endpoint(self) -> str: + assert self._endpoint is not None + return self._endpoint + + def __del__(self): + self.stop() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stop() diff --git 
a/tensorrt_llm/_torch/disaggregation/native/utils.py b/tensorrt_llm/_torch/disaggregation/native/utils.py new file mode 100644 index 00000000000..b4d0f5a7391 --- /dev/null +++ b/tensorrt_llm/_torch/disaggregation/native/utils.py @@ -0,0 +1,34 @@ +def get_local_ip() -> str: + try: + import socket + + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: + s.connect(("8.8.8.8", 80)) + ip = s.getsockname()[0] + if not ip.startswith("127."): + return ip + except OSError: + pass + + try: + import netifaces + + for iface in netifaces.interfaces(): + addrs = netifaces.ifaddresses(iface) + if netifaces.AF_INET in addrs: + for addr in addrs[netifaces.AF_INET]: + ip = addr.get("addr", "") + if not ip.startswith("127.") and not ip.startswith("169.254"): + return ip + except Exception: + pass + + try: + hostname = socket.gethostname() + ip = socket.gethostbyname(hostname) + if not ip.startswith("127."): + return ip + except OSError: + pass + + return "127.0.0.1" diff --git a/tensorrt_llm/_torch/disaggregation/nixl/__init__.py b/tensorrt_llm/_torch/disaggregation/nixl/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tensorrt_llm/_torch/disaggregation/nixl/agent.py b/tensorrt_llm/_torch/disaggregation/nixl/agent.py new file mode 100644 index 00000000000..e165f9575ca --- /dev/null +++ b/tensorrt_llm/_torch/disaggregation/nixl/agent.py @@ -0,0 +1,253 @@ +"""NIXL Transfer Agent implementations. + +This module provides two implementations: +1. BindingsNixlTransferAgent - Uses the standalone nixl_bindings C++ module with GIL release support +2. NixlTransferAgent - Uses the Python nixl library directly (fallback) + +The standalone nixl_bindings module is separate from the main trtllm bindings, +so trtllm can still function normally even without NIXL dependencies. 
+""" + +import time + +from tensorrt_llm._utils import nvtx_range + +# Import base classes for type compatibility +from ..base.agent import BaseTransferAgent, RegMemoryDescs, TransferRequest, TransferStatus + +# Try to import the standalone tensorrt_llm_transfer_agent_binding module +# Located at tensorrt_llm/ (same level as bindings.so) +_AGENT_BINDING_AVAILABLE = False +try: + import tensorrt_llm.tensorrt_llm_transfer_agent_binding as _agent_binding # noqa: E402 + + _AGENT_BINDING_AVAILABLE = True + + # Import from standalone module + BaseAgentConfig = _agent_binding.BaseAgentConfig + CppNixlTransferAgent = _agent_binding.NixlTransferAgent + CppNixlTransferStatus = _agent_binding.NixlTransferStatus + MemoryType = _agent_binding.MemoryType + MemoryDescs = _agent_binding.MemoryDescs + AgentDesc = _agent_binding.AgentDesc + TransferState = _agent_binding.TransferState +except ImportError: + # tensorrt_llm_transfer_agent_binding not available, will fall back to Python nixl or raise error + pass + + +def is_transfer_agent_binding_available() -> bool: + """Check if the standalone tensorrt_llm_transfer_agent_binding module is available.""" + return _AGENT_BINDING_AVAILABLE + + +class BindingsNixlTransferStatus(TransferStatus): + """TransferStatus wrapper using C++ bindings with GIL release.""" + + def __init__(self, cpp_status): + self._cpp_status = cpp_status + + def is_completed(self) -> bool: + """Check if transfer is completed (releases GIL).""" + return self._cpp_status.is_completed() + + @nvtx_range("BindingsNixlTransferStatus.wait") + def wait(self) -> bool: + """Wait for transfer to complete (releases GIL).""" + return self._cpp_status.wait() == TransferState.SUCCESS + + +class BindingsNixlTransferAgent(BaseTransferAgent): + """NixlTransferAgent using C++ bindings with GIL release support. + + This implementation uses the standalone nixl_bindings C++ module which releases + the GIL during blocking operations like wait(). + + The nixl_bindings module is independent from the main trtllm bindings, + so trtllm can function normally even without NIXL. + """ + + def __init__(self, name: str, use_prog_thread: bool = True, num_workers: int = 1): + if not _AGENT_BINDING_AVAILABLE: + raise ImportError( + "tensorrt_llm_transfer_agent_binding module is not available. " + "Please build with NIXL_ROOT set to enable NIXL support." 
+ ) + config = BaseAgentConfig( + name, + use_prog_thread, + multi_thread=False, + use_listen_thread=False, + num_workers=num_workers, + ) + self._cpp_agent = CppNixlTransferAgent(config) + self.name = name + + def register_memory(self, descs: RegMemoryDescs): + """Register memory regions.""" + cpp_descs = self._convert_reg_memory_descs(descs) + self._cpp_agent.register_memory(cpp_descs) + + def deregister_memory(self, descs: RegMemoryDescs): + """Deregister memory regions.""" + cpp_descs = self._convert_reg_memory_descs(descs) + self._cpp_agent.deregister_memory(cpp_descs) + + def load_remote_agent(self, name: str, agent_desc: bytes): + """Load a remote agent by its descriptor (bytes).""" + # AgentDesc expects std::string which can hold binary data + desc_str = agent_desc if isinstance(agent_desc, bytes) else agent_desc.encode() + cpp_desc = AgentDesc(desc_str) + self._cpp_agent.load_remote_agent(name, cpp_desc) + + def load_remote_agent_by_connection(self, name: str, connection_info: str): + """Load a remote agent by connection info.""" + self._cpp_agent.load_remote_agent_by_connection(name, connection_info) + + def get_local_agent_desc(self) -> bytes: + """Get the local agent descriptor as bytes.""" + agent_desc = self._cpp_agent.get_local_agent_desc() + return agent_desc.backend_agent_desc # Returns bytes + + def get_local_connection_info(self) -> str: + """Get the local connection info.""" + return self._cpp_agent.get_local_connection_info() + + def invalidate_remote_agent(self, name: str): + """Invalidate a remote agent.""" + self._cpp_agent.invalidate_remote_agent(name) + + def check_remote_descs(self, name: str, memory_descs: MemoryDescs) -> bool: + """Check if remote descriptors are available. + + memory_descs should be C++ MemoryDescs type. + """ + return self._cpp_agent.check_remote_descs(name, memory_descs) + + def notify_sync_message(self, name: str, sync_message: str): + """Send a sync message to a remote agent.""" + self._cpp_agent.notify_sync_message(name, sync_message) + + def get_notified_sync_messages(self): + """Get notified sync messages.""" + return self._cpp_agent.get_notified_sync_messages() + + @nvtx_range("BindingsNixlTransferAgent.submit_transfer_requests") + def submit_transfer_requests(self, request: TransferRequest) -> TransferStatus: + """Submit transfer requests and return status. + + request should be a C++ TransferRequest (from tensorrt_llm_transfer_agent_binding). + """ + cpp_status = self._cpp_agent.submit_transfer_requests(request) + return BindingsNixlTransferStatus(cpp_status) + + def _convert_reg_memory_descs(self, descs: RegMemoryDescs) -> "MemoryDescs": + """Convert Python RegMemoryDescs to C++ MemoryDescs. + + RegMemoryDescs.descs is List[Tuple[int, int, int, str]] = (ptr, size, device_id, name) + Extract first 3 elements for C++ batch constructor. 
+ """ + mem_type = self._convert_memory_type(descs.type) + # Extract (ptr, size, device_id) from 4-tuple, discard name + tuples = [(d[0], d[1], d[2]) for d in descs.descs] + return MemoryDescs(mem_type, tuples) + + def _convert_memory_type(self, py_type: str) -> "MemoryType": + """Convert Python memory type string to C++ MemoryType.""" + type_map = { + "DRAM": MemoryType.DRAM, + "VRAM": MemoryType.VRAM, + "GPU": MemoryType.VRAM, + "BLK": MemoryType.BLK, + "OBJ": MemoryType.OBJ, + "FILE": MemoryType.FILE, + } + return type_map.get(py_type.upper(), MemoryType.VRAM) + + +# For backward compatibility, also keep the Python nixl-based implementation +NixlTransferAgent = None +NixlTransferStatus = None + +try: + from nixl import nixl_agent, nixl_agent_config, nixl_xfer_handle # noqa: E402 + + class NixlTransferStatus(TransferStatus): + """TransferStatus using Python nixl library.""" + + def __init__(self, agent: nixl_agent, handle: nixl_xfer_handle): + self.agent = agent + self.handle = handle + + def is_completed(self): + status = self.agent.check_xfer_state(self.handle) + return status == "DONE" + + def wait(self): + status = "PROC" + sleep_time = 0.0001 # 0.1ms + max_sleep_time = 0.01 # 10ms + while status == "PROC": + status = self.agent.check_xfer_state(self.handle) + if status == "ERROR": + return False # transfer failed + # sleep(0.1) + # sleep to release GIL + time.sleep(sleep_time) + sleep_time = min(sleep_time * 2, max_sleep_time) + return True + + class NixlTransferAgent(BaseTransferAgent): + """NixlTransferAgent using Python nixl library.""" + + def __init__(self, name: str, use_prog_thread: bool, num_workers: int = 1): + self.name = name + agent_config = nixl_agent_config( + enable_prog_thread=use_prog_thread, backends=["UCX"], num_threads=num_workers + ) + self.agent = nixl_agent(name, agent_config) + + def register_memory(self, descs: RegMemoryDescs): + reg_descs = self.agent.get_reg_descs(descs.descs, descs.type) + self.agent.register_memory(reg_descs) + + def deregister_memory(self, descs: RegMemoryDescs): + self.agent.deregister_memory(descs.descs, descs.type) + + def load_remote_agent(self, name: str, agent_desc: bytes): + self.agent.add_remote_agent(agent_desc) + + def get_local_agent_desc(self): + return self.agent.get_agent_metadata() + + def invalidate_remote_agent(self, name: str): + self.agent.remove_remote_agent(name) + + def check_remote_descs(self, name: str, memory_descs: list[int]) -> bool: + raise NotImplementedError + + def notify_sync_message(self, name: str, sync_message: str): + raise NotImplementedError + + @nvtx_range("NixlTransferAgent.submit_transfer_requests") + def submit_transfer_requests(self, request: TransferRequest) -> TransferStatus: + src_xfer_descs = self.agent.get_xfer_descs( + request.src_descs.descs, request.src_descs.type + ) + dst_xfer_descs = self.agent.get_xfer_descs( + request.dst_descs.descs, request.dst_descs.type + ) + handle = self.agent.initialize_xfer( + request.op, + src_xfer_descs, + dst_xfer_descs, + request.remote_name, + request.sync_message, + ) + status = self.agent.transfer(handle) + assert status != "ERROR" + return NixlTransferStatus(self.agent, handle) + +except ImportError: + # nixl library not available + pass From cd618b93d7a56f7d54a6a6f3eadc0edbc8de2e52 Mon Sep 17 00:00:00 2001 From: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com> Date: Thu, 8 Jan 2026 10:35:01 +0000 Subject: [PATCH 2/3] update tests Signed-off-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com> --- 
.../_torch/disaggregation/base/agent.py | 71 +++++++--- .../_torch/disaggregation/base/kv_transfer.py | 84 +++++++++--- .../_torch/disaggregation/native/messenger.py | 117 ++++++++++------ .../unittest/disaggregated/test_messenger.py | 127 ++++++++++++++++++ 4 files changed, 320 insertions(+), 79 deletions(-) create mode 100644 tests/unittest/disaggregated/test_messenger.py diff --git a/tensorrt_llm/_torch/disaggregation/base/agent.py b/tensorrt_llm/_torch/disaggregation/base/agent.py index b677f24a489..dd55ad08a8d 100644 --- a/tensorrt_llm/_torch/disaggregation/base/agent.py +++ b/tensorrt_llm/_torch/disaggregation/base/agent.py @@ -1,23 +1,28 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import List, Tuple, Union +from typing import Enum, List, Tuple, Union + +from tensorrt_llm.logger import logger # Try to import C++ bindings for zero-copy performance -_CPP_BINDING_AVAILABLE = False try: - import tensorrt_llm.tensorrt_llm_transfer_agent_binding as _cpp_binding + from tensorrt_llm.tensorrt_llm_transfer_agent_binding import ( + BaseTransferAgent, + MemoryDesc, + MemoryDescs, + MemoryType, + TransferOp, + TransferRequest, + TransferStatus, + ) _CPP_BINDING_AVAILABLE = True - # Use C++ types directly when available - MemoryType = _cpp_binding.MemoryType - TransferOp = _cpp_binding.TransferOp - MemoryDesc = _cpp_binding.MemoryDesc - MemoryDescs = _cpp_binding.MemoryDescs - TransferRequest = _cpp_binding.TransferRequest - TransferStatus = _cpp_binding.TransferStatus - BaseTransferAgent = _cpp_binding.BaseTransferAgent except ImportError: _CPP_BINDING_AVAILABLE = False + logger.warning( + "C++ transfer agent bindings not available. " + "Falling back to Python implementations which may have lower performance." + ) def is_cpp_binding_available() -> bool: @@ -28,11 +33,11 @@ def is_cpp_binding_available() -> bool: # Fallback Python implementations when C++ bindings not available if not _CPP_BINDING_AVAILABLE: - class TransferOp: + class TransferOp(Enum): READ = "READ" WRITE = "WRITE" - class MemoryType: + class MemoryType(Enum): DRAM = "DRAM" VRAM = "VRAM" BLK = "BLK" @@ -67,28 +72,52 @@ def wait(self, timeout: float | None = None) -> None: ... class BaseTransferAgent(ABC): @abstractmethod - def register_memory(self, descs: MemoryDescs) -> None: ... + def register_memory(self, descs: MemoryDescs) -> None: + """Register a set of memory descriptors on the agent.""" + ... @abstractmethod - def deregister_memory(self, descs: MemoryDescs) -> None: ... + def deregister_memory(self, descs: MemoryDescs) -> None: + """De-register a set of memory descriptors on the agent.""" + ... @abstractmethod - def load_remote_agent(self, name: str, agent_desc: str) -> None: ... + def load_remote_agent(self, name: str, agent_desc: str) -> None: + """ + Load information about a remote agent specified by name. + + Args: + name (str): The remote agent's identifier. + agent_desc (str): A serialized description of the agent. + """ + ... @abstractmethod - def get_local_agent_desc(self) -> str: ... + def get_local_agent_desc(self) -> str: + """Return the serialized description of this agent.""" + ... @abstractmethod - def invalidate_remote_agent(self, name: str) -> None: ... + def invalidate_remote_agent(self, name: str) -> None: + """Invalidate any cached information about the specified remote agent.""" + ... @abstractmethod - def submit_transfer_requests(self, request: TransferRequest) -> TransferStatus: ... 
+ def submit_transfer_requests(self, request: TransferRequest) -> TransferStatus: + """Submit transfer tasks to the agent based on a request.""" + ... @abstractmethod - def notify_sync_message(self, name: str, sync_message: str) -> None: ... + def notify_sync_message(self, name: str, sync_message: str) -> None: + """Send a synchronization message to the specified remote agent.""" + ... @abstractmethod - def check_remote_descs(self, name: str, memory_descs: List[int]) -> bool: ... + def check_remote_descs(self, name: str, memory_descs: List[int]) -> bool: + """ + Verify the remote agent's memory descriptors. + """ + ... # RegMemoryDescs is Python-only (used for registration with name field) diff --git a/tensorrt_llm/_torch/disaggregation/base/kv_transfer.py b/tensorrt_llm/_torch/disaggregation/base/kv_transfer.py index ea0597eaa8d..e45f4856f0e 100644 --- a/tensorrt_llm/_torch/disaggregation/base/kv_transfer.py +++ b/tensorrt_llm/_torch/disaggregation/base/kv_transfer.py @@ -21,6 +21,8 @@ class KVSlice: class SessionStatus(Enum): + """Status of a transfer session.""" + INIT = "INIT" READY = "READY" TRANSFERRING = "TRANSFERRING" @@ -36,63 +38,107 @@ class SessionStatus(Enum): @dataclass class SessionState: + """State of a transfer session.""" + status: SessionStatus finished_tasks: List[TaskIdType] @dataclass class SessionArgsBase: - request_id: int + """Base arguments for transfer sessions.""" + params: DisaggregatedParams -class SenderBase(ABC): ... +class SenderBase(ABC): + """Base class for sending KV cache data.""" + + ... -class ReceiverBase(ABC): ... +class ReceiverBase(ABC): + """Base class for receiving KV cache data.""" + + ... class TxSessionBase(ABC): def __init__(self, sender: SenderBase, args: SessionArgsBase): + """ + Initializes the transmission session. + :param sender: The sender instance responsible for sending data. + :param args: The session arguments. + """ self._base_args = args @property @abstractmethod - def state(self) -> SessionState: ... + def state(self) -> SessionState: + """ + Returns the current state of the session. + """ + ... @abstractmethod - def poll_task(self, id: TaskIdType) -> SessionStatus: ... + def poll_task(self, id: TaskIdType) -> SessionStatus: + """ + Polls the status of a specific task by its ID. + :param id: The task ID to poll. + """ + ... @abstractmethod - def send(self, slice: KVSlice) -> TaskIdType: ... - - """ - Async send slice to the peer. return the task id. Task state can be polled by poll_task(). - """ + def send(self, slice: KVSlice) -> TaskIdType: + """ + Sends a slice of KV cache data and returns the task ID. + :param slice: The KV slice to send. + """ + ... @property @abstractmethod - def exception(self) -> Optional[Exception]: ... + def exception(self) -> Optional[Exception]: + """ + Returns any exception that occurred during the session. + """ + ... class RxSessionBase(ABC): def __init__(self, receiver: ReceiverBase, args: SessionArgsBase): + """ + Initializes the reception session. + :param receiver: The receiver instance responsible for receiving data. + """ self._base_args = args @property @abstractmethod - def state(self) -> SessionState: ... + def state(self) -> SessionState: + """ + Returns the current state of the session. + """ + ... @abstractmethod - def poll_task(self, id: TaskIdType) -> SessionStatus: ... + def poll_task(self, id: TaskIdType) -> SessionStatus: + """ + Polls the status of a specific task by its ID. + :param id: The task ID to poll. + """ + ... 
@abstractmethod - def receive(self, slice: KVSlice) -> TaskIdType: ... - - """ - Async receive slice from the peer. return the task id. Task state can be polled by poll_task(). - """ + def receive(self, slice: KVSlice) -> TaskIdType: + """ + Receives a slice of KV cache data and returns the task ID. + :param slice: The KV slice to receive. + """ + ... @property @abstractmethod - def exception(self) -> Optional[Exception]: ... + def exception(self) -> Optional[Exception]: + """Returns any exception that occurred during the session.""" + ... diff --git a/tensorrt_llm/_torch/disaggregation/native/messenger.py b/tensorrt_llm/_torch/disaggregation/native/messenger.py index d9e0d92d886..37e857c6006 100644 --- a/tensorrt_llm/_torch/disaggregation/native/messenger.py +++ b/tensorrt_llm/_torch/disaggregation/native/messenger.py @@ -71,50 +71,52 @@ def endpoint(self) -> str: ... -def decode_message(message: list[bytes], encoding: str = "ascii", errors: str = "strict") -> tuple: +def decode_message( + message: list[bytes], encoding: str = "ascii", err_mode: str = "strict" +) -> tuple: if not isinstance(message, list) or not all(isinstance(m, bytes) for m in message): raise ValueError("Input must be a list of bytes") - return tuple(m.decode(encoding, errors=errors) for m in message) + return tuple(m.decode(encoding, errors=err_mode) for m in message) class ZMQMessenger(MessengerInterface): - def __init__(self, mode: str, endpoint: Optional[str] = f"tcp://{get_local_ip()}:*"): + SOCKET_MODES = { + "ROUTER": zmq.ROUTER, # Handles multiple connections and routes messages by address. + "DEALER": zmq.DEALER, # Load balances outgoing messages and receives replies fairly. + "REQ": zmq.REQ, # Sends requests and waits for replies (synchronous). + "REP": zmq.REP, # Receives requests and sends replies (synchronous). 
+ } + + def __init__(self, mode: str, endpoint: Optional[str] = None): self._context = zmq.Context() self._mode = mode - self._socket = self._initialize_socket(mode) + self._socket = self._context.socket(self.SOCKET_MODES[mode]) + self._endpoint: Optional[str] = None + self._lock = Lock() self._closed = False - self._listener_thread: Optional[Thread] = None self._stop_event = Event() - self._lock = Lock() + self._listener_thread: Optional[Thread] = None + self._initialize_control_sockets() - self._internal_socket = self._context.socket(zmq.PAIR) + if endpoint is None: + endpoint = f"tcp://{get_local_ip()}:*" + + if mode in ["ROUTER", "REP"]: + self._socket.bind(endpoint) + self._endpoint = self._socket.getsockopt_string(zmq.LAST_ENDPOINT) + elif mode in ["DEALER", "REQ"]: + self._socket.connect(endpoint) + self._endpoint = endpoint + + logger.info(f"Initialized ZMQMessenger(mode={mode}, endpoint={self._endpoint})") + + def _initialize_control_sockets(self): self._control_socket = self._context.socket(zmq.PAIR) + self._internal_socket = self._context.socket(zmq.PAIR) inproc_endpoint = "inproc://stop_listener" self._control_socket.bind(inproc_endpoint) self._internal_socket.connect(inproc_endpoint) - if endpoint: - if mode in ["ROUTER", "REP"]: - self._socket.bind(endpoint) - self._endpoint = self._socket.getsockopt_string(zmq.LAST_ENDPOINT) - elif mode in ["DEALER", "REQ"]: - self._socket.connect(endpoint) - self._endpoint = endpoint - - logger.debug(f"Initializing ZMQMessenger, mode={mode}, endpoint={self._endpoint}") - - def _initialize_socket(self, mode): - if mode == "ROUTER": - return self._context.socket(zmq.ROUTER) - elif mode == "DEALER": - return self._context.socket(zmq.DEALER) - elif mode == "REQ": - return self._context.socket(zmq.REQ) - elif mode == "REP": - return self._context.socket(zmq.REP) - else: - raise ValueError(f"Unsupported ZeroMQ socket mode: {mode}") - def start(self): pass @@ -131,7 +133,14 @@ def send_encoded(self, *messages, encoding: str = "ascii"): def receive(self) -> list[bytes]: return self._socket.recv_multipart() - def start_listener(self, on_message: Callable[[list[bytes]], None]): + def start_listener( + self, + on_message: Callable[[list[bytes]], None], + on_error: Optional[Callable[[Exception], None]] = None, + ): + assert self._mode in ["ROUTER", "REP"], ( + "Listener can only be started in ROUTER or REP modes" + ) if self._listener_thread and self._listener_thread.is_alive(): raise RuntimeError("Listener already running") @@ -147,32 +156,62 @@ def listener(): self._stop_event.set() elif self._socket in events: messages = self.receive() - persist = on_message(messages) - if persist is False: - self._stop_event.set() + try: + persist = on_message(messages) + if persist is False: + self._stop_event.set() + except Exception as e: + logger.error(f"Error in on_message callback: {e}") + if on_error: + on_error(e) + else: + self._stop_event.set() except zmq.ZMQError as e: logger.error(f"ZMQ Error in listener: {e}") - continue + if on_error: + on_error(e) + break except Exception as e: logger.error(f"Error in listener: {e}") - continue + if on_error: + on_error(e) + break + + self._stop_event.set() self._listener_thread = Thread(target=listener, daemon=True) self._listener_thread.start() def stop(self, timeout=5): + def _close_socket(socket: zmq.Socket): + try: + if not socket.closed: + socket.close() + except Exception as e: + logger.error(f"Error closing socket: {e}") + with self._lock: if self._closed: return self._closed = True + logger.debug("Stopping 
ZMQMessenger...") + self._stop_event.set() self._internal_socket.send(b"STOP") if self._listener_thread: self._listener_thread.join(timeout) - self._socket.close() - self._internal_socket.close() - self._control_socket.close() - self._context.term() + if self._listener_thread.is_alive(): + logger.warning("Listener thread did not terminate within timeout") + + _close_socket(self._socket) + _close_socket(self._internal_socket) + _close_socket(self._control_socket) + + try: + if self._context.closed: + self._context.term() + except Exception as e: + logger.error(f"Error terminating ZMQ context: {e}") @property def endpoint(self) -> str: diff --git a/tests/unittest/disaggregated/test_messenger.py b/tests/unittest/disaggregated/test_messenger.py new file mode 100644 index 00000000000..94e052d531e --- /dev/null +++ b/tests/unittest/disaggregated/test_messenger.py @@ -0,0 +1,127 @@ +import socket +import time +import unittest + +import pytest +from parameterized import parameterized + +from tensorrt_llm._torch.disaggregation.native.messenger import ZMQMessenger, decode_message +from tensorrt_llm._torch.disaggregation.native.utils import get_local_ip + +TEST_CASES = [ + { + "name": "valid_message", + "message": [b"hello", b"world"], + "encoding": "utf-8", + "err_mode": "strict", + "expected": ("hello", "world"), + "raises": None, + }, + { + "name": "invalid_input", + "message": ["hello", b"world"], + "encoding": "utf-8", + "err_mode": "strict", + "expected": None, + "raises": ValueError, + }, + { + "name": "decoding_error", + "message": [b"\xff"], + "encoding": "utf-8", + "err_mode": "strict", + "expected": None, + "raises": UnicodeDecodeError, + }, + { + "name": "decoding_with_ignore", + "message": [b"\xff"], + "encoding": "utf-8", + "err_mode": "ignore", + "expected": ("",), + "raises": None, + }, +] + + +class TestDecodeMessage(unittest.TestCase): + @parameterized.expand([(case["name"], case) for case in TEST_CASES]) + def test_decode_message(self, name, case): + message = case["message"] + encoding = case["encoding"] + err_mode = case["err_mode"] + expected = case["expected"] + raises = case["raises"] + + if raises: + with self.assertRaises(raises): + decode_message(message, encoding=encoding, err_mode=err_mode) + else: + decoded = decode_message(message, encoding=encoding, err_mode=err_mode) + self.assertEqual(decoded, expected) + + +@pytest.fixture +def dynamic_endpoint(): + """Fixture to dynamically generate an available endpoint with a free port.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) # Bind to an available port provided by the OS + port = s.getsockname()[1] + return f"tcp://{get_local_ip()}:{port}" + + +@pytest.fixture +def create_messenger_pair(dynamic_endpoint): + def _create_messenger_pair(mode1, mode2): + messenger1 = ZMQMessenger( + mode1, endpoint=dynamic_endpoint if mode1 in ["ROUTER", "REP"] else None + ) + messenger2 = ZMQMessenger( + mode2, endpoint=dynamic_endpoint if mode2 in ["DEALER", "REQ"] else None + ) + return messenger1, messenger2 + + yield _create_messenger_pair + + +def test_router_dealer(create_messenger_pair): + """Test ROUTER and DEALER communication.""" + router, dealer = create_messenger_pair("ROUTER", "DEALER") + + received_messages = [] + + def on_message(messages): + received_messages.extend(messages) + + router.start_listener(on_message) + + dealer.send([b"Hello, ROUTER!"]) + + time.sleep(0.1) + + assert len(received_messages) > 0 + assert b"Hello, ROUTER!" 
in received_messages + + router.stop() + dealer.stop() + + +def test_req_rep(create_messenger_pair): + """Test REQ and REP communication.""" + rep, req = create_messenger_pair("REP", "REQ") + + def on_message(messages): + rep.send(messages) + + rep.start_listener(on_message) + + req.send([b"Hello, REP!"]) + response = req.receive() + assert response == [b"Hello, REP!"] + + req.stop() + rep.stop() + + +if __name__ == "__main__": + unittest.main() From 5cca4a0edf4d9df25863db4bdbff95bce872de21 Mon Sep 17 00:00:00 2001 From: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com> Date: Thu, 8 Jan 2026 13:40:51 +0000 Subject: [PATCH 3/3] refactor the agent Signed-off-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com> --- .../_torch/disaggregation/base/agent.py | 193 +++++--------- .../_torch/disaggregation/native/messenger.py | 2 +- .../_torch/disaggregation/nixl/_agent_cpp.py | 108 ++++++++ .../disaggregation/nixl/_agent_python.py | 81 ++++++ .../_torch/disaggregation/nixl/agent.py | 246 +----------------- 5 files changed, 271 insertions(+), 359 deletions(-) create mode 100644 tensorrt_llm/_torch/disaggregation/nixl/_agent_cpp.py create mode 100644 tensorrt_llm/_torch/disaggregation/nixl/_agent_python.py diff --git a/tensorrt_llm/_torch/disaggregation/base/agent.py b/tensorrt_llm/_torch/disaggregation/base/agent.py index dd55ad08a8d..bc9f10a2a72 100644 --- a/tensorrt_llm/_torch/disaggregation/base/agent.py +++ b/tensorrt_llm/_torch/disaggregation/base/agent.py @@ -2,126 +2,77 @@ from dataclasses import dataclass from typing import Enum, List, Tuple, Union -from tensorrt_llm.logger import logger - -# Try to import C++ bindings for zero-copy performance -try: - from tensorrt_llm.tensorrt_llm_transfer_agent_binding import ( - BaseTransferAgent, - MemoryDesc, - MemoryDescs, - MemoryType, - TransferOp, - TransferRequest, - TransferStatus, - ) - - _CPP_BINDING_AVAILABLE = True -except ImportError: - _CPP_BINDING_AVAILABLE = False - logger.warning( - "C++ transfer agent bindings not available. " - "Falling back to Python implementations which may have lower performance." - ) - - -def is_cpp_binding_available() -> bool: - """Check if C++ transfer agent bindings are available.""" - return _CPP_BINDING_AVAILABLE - - -# Fallback Python implementations when C++ bindings not available -if not _CPP_BINDING_AVAILABLE: - - class TransferOp(Enum): - READ = "READ" - WRITE = "WRITE" - - class MemoryType(Enum): - DRAM = "DRAM" - VRAM = "VRAM" - BLK = "BLK" - OBJ = "OBJ" - FILE = "FILE" - - @dataclass - class MemoryDesc: - ptr: int - size: int - device_id: int - - @dataclass - class MemoryDescs: - type: str - descs: List[Union[Tuple[int, int, int], MemoryDesc]] - - @dataclass - class TransferRequest: - op: TransferOp - src_descs: MemoryDescs - dst_descs: MemoryDescs - remote_name: str - sync_message: str - - class TransferStatus(ABC): - @abstractmethod - def is_completed(self) -> bool: ... - - @abstractmethod - def wait(self, timeout: float | None = None) -> None: ... - - class BaseTransferAgent(ABC): - @abstractmethod - def register_memory(self, descs: MemoryDescs) -> None: - """Register a set of memory descriptors on the agent.""" - ... - - @abstractmethod - def deregister_memory(self, descs: MemoryDescs) -> None: - """De-register a set of memory descriptors on the agent.""" - ... - - @abstractmethod - def load_remote_agent(self, name: str, agent_desc: str) -> None: - """ - Load information about a remote agent specified by name. - - Args: - name (str): The remote agent's identifier. 
- agent_desc (str): A serialized description of the agent. - """ - ... - - @abstractmethod - def get_local_agent_desc(self) -> str: - """Return the serialized description of this agent.""" - ... - - @abstractmethod - def invalidate_remote_agent(self, name: str) -> None: - """Invalidate any cached information about the specified remote agent.""" - ... - - @abstractmethod - def submit_transfer_requests(self, request: TransferRequest) -> TransferStatus: - """Submit transfer tasks to the agent based on a request.""" - ... - - @abstractmethod - def notify_sync_message(self, name: str, sync_message: str) -> None: - """Send a synchronization message to the specified remote agent.""" - ... - - @abstractmethod - def check_remote_descs(self, name: str, memory_descs: List[int]) -> bool: - """ - Verify the remote agent's memory descriptors. - """ - ... - - -# RegMemoryDescs is Python-only (used for registration with name field) + +# Common Enumerations +class TransferOp(Enum): + READ = "READ" + WRITE = "WRITE" + + +class MemoryType(Enum): + DRAM = "DRAM" + VRAM = "VRAM" + BLK = "BLK" + OBJ = "OBJ" + FILE = "FILE" + + +# Common Data Structures @dataclass -class RegMemoryDescs: +class MemoryDesc: + ptr: int + size: int + device_id: int + + +@dataclass +class MemoryDescs: type: str - descs: List[Tuple[int, int, int, str]] + descs: List[Union[Tuple[int, int, int], MemoryDesc]] + + +@dataclass +class TransferRequest: + op: TransferOp + src_descs: MemoryDescs + dst_descs: MemoryDescs + remote_name: str + sync_message: str + + +class BaseTransferStatus(ABC): + """Abstract base class for transfer status.""" + + @abstractmethod + def is_completed(self) -> bool: ... + + @abstractmethod + def wait(self, timeout: float | None = None) -> None: ... + + +class BaseTransferAgent(ABC): + """Abstract base class for transfer agents.""" + + @abstractmethod + def register_memory(self, descs: MemoryDescs) -> None: ... + + @abstractmethod + def deregister_memory(self, descs: MemoryDescs) -> None: ... + + @abstractmethod + def load_remote_agent(self, name: str, agent_desc: str) -> None: ... + + @abstractmethod + def get_local_agent_desc(self) -> str: ... + + @abstractmethod + def invalidate_remote_agent(self, name: str) -> None: ... + + @abstractmethod + def submit_transfer_requests(self, request: TransferRequest) -> BaseTransferStatus: ... + + @abstractmethod + def notify_sync_message(self, name: str, sync_message: str) -> None: ... + + @abstractmethod + def check_remote_descs(self, name: str, memory_descs: List[int]) -> bool: ... 
diff --git a/tensorrt_llm/_torch/disaggregation/native/messenger.py b/tensorrt_llm/_torch/disaggregation/native/messenger.py index 37e857c6006..ed72cc04446 100644 --- a/tensorrt_llm/_torch/disaggregation/native/messenger.py +++ b/tensorrt_llm/_torch/disaggregation/native/messenger.py @@ -150,7 +150,7 @@ def listener(): poller.register(self._control_socket, zmq.POLLIN) while not self._stop_event.is_set(): - events = dict(poller.poll()) + events = dict(poller.poll(timeout=100)) try: if self._control_socket in events: self._stop_event.set() diff --git a/tensorrt_llm/_torch/disaggregation/nixl/_agent_cpp.py b/tensorrt_llm/_torch/disaggregation/nixl/_agent_cpp.py new file mode 100644 index 00000000000..214fe929838 --- /dev/null +++ b/tensorrt_llm/_torch/disaggregation/nixl/_agent_cpp.py @@ -0,0 +1,108 @@ +from tensorrt_llm._utils import nvtx_range +from tensorrt_llm.tensorrt_llm_transfer_agent_binding import ( + AgentDesc, + BaseAgentConfig, + MemoryDescs, + MemoryType, + TransferState, +) +from tensorrt_llm.tensorrt_llm_transfer_agent_binding import ( + NixlTransferAgent as CppNixlTransferAgent, +) +from tensorrt_llm.tensorrt_llm_transfer_agent_binding import ( + NixlTransferStatus as CppNixlTransferStatus, +) + +from ..base.agent import BaseTransferAgent, BaseTransferStatus, RegMemoryDescs, TransferRequest + + +class NixlTransferStatus(BaseTransferStatus): + def __init__(self, cpp_status: CppNixlTransferStatus): + self._cpp_status = cpp_status + + def is_completed(self) -> bool: + """Check if transfer is completed (releases GIL).""" + return self._cpp_status.is_completed() + + @nvtx_range("NixlTransferStatus.wait") + def wait(self, timeout: float = None) -> bool: + """Wait for transfer to complete (releases GIL).""" + return self._cpp_status.wait() == TransferState.SUCCESS + + +class NixlTransferAgent(BaseTransferAgent): + """NixlTransferAgent using C++ bindings with GIL release support. + + This implementation uses the standalone nixl_bindings C++ module which releases + the GIL during blocking operations like wait(). + + The nixl_bindings module is independent from the main trtllm bindings, + so trtllm can function normally even without NIXL. 
+ """ + + def __init__(self, name: str, use_prog_thread: bool = True, num_workers: int = 1): + config = BaseAgentConfig( + name=name, + use_prog_thread=use_prog_thread, + multi_thread=False, + use_listen_thread=False, + num_workers=num_workers, + ) + self._cpp_agent = CppNixlTransferAgent(config) + self.name = name + + def register_memory(self, descs: RegMemoryDescs): + cpp_descs = self._convert_reg_memory_descs(descs) + self._cpp_agent.register_memory(cpp_descs) + + def deregister_memory(self, descs: RegMemoryDescs): + cpp_descs = self._convert_reg_memory_descs(descs) + self._cpp_agent.deregister_memory(cpp_descs) + + def load_remote_agent(self, name: str, agent_desc: bytes): + desc_str = agent_desc if isinstance(agent_desc, bytes) else agent_desc.encode() + cpp_desc = AgentDesc(desc_str) + self._cpp_agent.load_remote_agent(name, cpp_desc) + + def load_remote_agent_by_connection(self, name: str, connection_info: str): + self._cpp_agent.load_remote_agent_by_connection(name, connection_info) + + def get_local_agent_desc(self) -> bytes: + agent_desc = self._cpp_agent.get_local_agent_desc() + return agent_desc.backend_agent_desc + + def get_local_connection_info(self) -> str: + return self._cpp_agent.get_local_connection_info() + + def invalidate_remote_agent(self, name: str): + self._cpp_agent.invalidate_remote_agent(name) + + def check_remote_descs(self, name: str, memory_descs: MemoryDescs) -> bool: + return self._cpp_agent.check_remote_descs(name, memory_descs) + + def notify_sync_message(self, name: str, sync_message: str): + self._cpp_agent.notify_sync_message(name, sync_message) + + def get_notified_sync_messages(self): + return self._cpp_agent.get_notified_sync_messages() + + @nvtx_range("BindingsNixlTransferAgent.submit_transfer_requests") + def submit_transfer_requests(self, request: TransferRequest) -> BaseTransferStatus: + cpp_status = self._cpp_agent.submit_transfer_requests(request) + return NixlTransferStatus(cpp_status) + + def _convert_reg_memory_descs(self, descs: RegMemoryDescs) -> "MemoryDescs": + mem_type = self._convert_memory_type(descs.type) + tuples = [(d[0], d[1], d[2]) for d in descs.descs] # Extract (ptr, size, device_id) + return MemoryDescs(mem_type, tuples) + + def _convert_memory_type(self, py_type: str) -> "MemoryType": + type_map = { + "DRAM": MemoryType.DRAM, + "VRAM": MemoryType.VRAM, + "GPU": MemoryType.VRAM, + "BLK": MemoryType.BLK, + "OBJ": MemoryType.OBJ, + "FILE": MemoryType.FILE, + } + return type_map.get(py_type.upper(), MemoryType.VRAM) diff --git a/tensorrt_llm/_torch/disaggregation/nixl/_agent_python.py b/tensorrt_llm/_torch/disaggregation/nixl/_agent_python.py new file mode 100644 index 00000000000..2466e924170 --- /dev/null +++ b/tensorrt_llm/_torch/disaggregation/nixl/_agent_python.py @@ -0,0 +1,81 @@ +import time + +from nixl import nixl_agent, nixl_agent_config, nixl_xfer_handle + +from tensorrt_llm._utils import nvtx_range + +from ..base.agent import BaseTransferAgent, BaseTransferStatus, RegMemoryDescs, TransferRequest + + +class NixlTransferStatus(BaseTransferStatus): + """TransferStatus using the Python NIXL library.""" + + def __init__(self, agent: nixl_agent, handle: nixl_xfer_handle): + self.agent = agent + self.handle = handle + + def is_completed(self) -> bool: + status = self.agent.check_xfer_state(self.handle) + return status == "DONE" + + def wait(self) -> bool: + status = "PROC" + sleep_time = 0.0001 # 0.1ms + max_sleep_time = 0.01 # 10ms + while status == "PROC": + status = self.agent.check_xfer_state(self.handle) + if status == 
"ERROR": + return False # Transfer failed + time.sleep(sleep_time) # Sleep to release GIL + sleep_time = min(sleep_time * 2, max_sleep_time) + return status == "DONE" + + +class NixlTransferAgent(BaseTransferAgent): + """Python-based TransferAgent using the NIXL library.""" + + def __init__(self, name: str, use_prog_thread: bool, num_workers: int = 1): + self.name = name + agent_config = nixl_agent_config( + enable_prog_thread=use_prog_thread, + backends=["UCX"], + num_threads=num_workers, + ) + self.agent = nixl_agent(name, agent_config) + + def register_memory(self, descs: RegMemoryDescs): + reg_descs = self.agent.get_reg_descs(descs.descs, descs.type) + self.agent.register_memory(reg_descs) + + def deregister_memory(self, descs: RegMemoryDescs): + self.agent.deregister_memory(descs.descs, descs.type) + + def load_remote_agent(self, name: str, agent_desc: bytes): + self.agent.add_remote_agent(agent_desc) + + def get_local_agent_desc(self) -> bytes: + return self.agent.get_agent_metadata() + + def invalidate_remote_agent(self, name: str): + self.agent.remove_remote_agent(name) + + def check_remote_descs(self, name: str, memory_descs: list[int]) -> bool: + raise NotImplementedError("check_remote_descs is not implemented.") + + def notify_sync_message(self, name: str, sync_message: str): + raise NotImplementedError("notify_sync_message is not implemented.") + + @nvtx_range("NixlTransferAgent.submit_transfer_requests") + def submit_transfer_requests(self, request: TransferRequest) -> BaseTransferStatus: + src_xfer_descs = self.agent.get_xfer_descs(request.src_descs.descs, request.src_descs.type) + dst_xfer_descs = self.agent.get_xfer_descs(request.dst_descs.descs, request.dst_descs.type) + handle = self.agent.initialize_xfer( + request.op, + src_xfer_descs, + dst_xfer_descs, + request.remote_name, + request.sync_message, + ) + status = self.agent.transfer(handle) + assert status != "ERROR", "Transfer failed during initialization." + return NixlTransferStatus(self.agent, handle) diff --git a/tensorrt_llm/_torch/disaggregation/nixl/agent.py b/tensorrt_llm/_torch/disaggregation/nixl/agent.py index e165f9575ca..a4776096007 100644 --- a/tensorrt_llm/_torch/disaggregation/nixl/agent.py +++ b/tensorrt_llm/_torch/disaggregation/nixl/agent.py @@ -1,3 +1,5 @@ +from tensorrt_llm import logger + """NIXL Transfer Agent implementations. This module provides two implementations: @@ -8,246 +10,16 @@ so trtllm can still function normally even without NIXL dependencies. 
""" -import time - -from tensorrt_llm._utils import nvtx_range - -# Import base classes for type compatibility -from ..base.agent import BaseTransferAgent, RegMemoryDescs, TransferRequest, TransferStatus - -# Try to import the standalone tensorrt_llm_transfer_agent_binding module -# Located at tensorrt_llm/ (same level as bindings.so) -_AGENT_BINDING_AVAILABLE = False try: - import tensorrt_llm.tensorrt_llm_transfer_agent_binding as _agent_binding # noqa: E402 + # Try to import the standalone tensorrt_llm_transfer_agent_binding module + # Located at tensorrt_llm/ (same level as bindings.so) + from ._agent_cpp import NixlTransferAgent, NixlTransferStatus - _AGENT_BINDING_AVAILABLE = True + logger.info("Using C++ NIXL TransferAgent") - # Import from standalone module - BaseAgentConfig = _agent_binding.BaseAgentConfig - CppNixlTransferAgent = _agent_binding.NixlTransferAgent - CppNixlTransferStatus = _agent_binding.NixlTransferStatus - MemoryType = _agent_binding.MemoryType - MemoryDescs = _agent_binding.MemoryDescs - AgentDesc = _agent_binding.AgentDesc - TransferState = _agent_binding.TransferState except ImportError: - # tensorrt_llm_transfer_agent_binding not available, will fall back to Python nixl or raise error - pass - - -def is_transfer_agent_binding_available() -> bool: - """Check if the standalone tensorrt_llm_transfer_agent_binding module is available.""" - return _AGENT_BINDING_AVAILABLE - - -class BindingsNixlTransferStatus(TransferStatus): - """TransferStatus wrapper using C++ bindings with GIL release.""" - - def __init__(self, cpp_status): - self._cpp_status = cpp_status - - def is_completed(self) -> bool: - """Check if transfer is completed (releases GIL).""" - return self._cpp_status.is_completed() - - @nvtx_range("BindingsNixlTransferStatus.wait") - def wait(self) -> bool: - """Wait for transfer to complete (releases GIL).""" - return self._cpp_status.wait() == TransferState.SUCCESS - - -class BindingsNixlTransferAgent(BaseTransferAgent): - """NixlTransferAgent using C++ bindings with GIL release support. - - This implementation uses the standalone nixl_bindings C++ module which releases - the GIL during blocking operations like wait(). - - The nixl_bindings module is independent from the main trtllm bindings, - so trtllm can function normally even without NIXL. - """ - - def __init__(self, name: str, use_prog_thread: bool = True, num_workers: int = 1): - if not _AGENT_BINDING_AVAILABLE: - raise ImportError( - "tensorrt_llm_transfer_agent_binding module is not available. " - "Please build with NIXL_ROOT set to enable NIXL support." 
- ) - config = BaseAgentConfig( - name, - use_prog_thread, - multi_thread=False, - use_listen_thread=False, - num_workers=num_workers, - ) - self._cpp_agent = CppNixlTransferAgent(config) - self.name = name - - def register_memory(self, descs: RegMemoryDescs): - """Register memory regions.""" - cpp_descs = self._convert_reg_memory_descs(descs) - self._cpp_agent.register_memory(cpp_descs) - - def deregister_memory(self, descs: RegMemoryDescs): - """Deregister memory regions.""" - cpp_descs = self._convert_reg_memory_descs(descs) - self._cpp_agent.deregister_memory(cpp_descs) - - def load_remote_agent(self, name: str, agent_desc: bytes): - """Load a remote agent by its descriptor (bytes).""" - # AgentDesc expects std::string which can hold binary data - desc_str = agent_desc if isinstance(agent_desc, bytes) else agent_desc.encode() - cpp_desc = AgentDesc(desc_str) - self._cpp_agent.load_remote_agent(name, cpp_desc) - - def load_remote_agent_by_connection(self, name: str, connection_info: str): - """Load a remote agent by connection info.""" - self._cpp_agent.load_remote_agent_by_connection(name, connection_info) - - def get_local_agent_desc(self) -> bytes: - """Get the local agent descriptor as bytes.""" - agent_desc = self._cpp_agent.get_local_agent_desc() - return agent_desc.backend_agent_desc # Returns bytes - - def get_local_connection_info(self) -> str: - """Get the local connection info.""" - return self._cpp_agent.get_local_connection_info() - - def invalidate_remote_agent(self, name: str): - """Invalidate a remote agent.""" - self._cpp_agent.invalidate_remote_agent(name) - - def check_remote_descs(self, name: str, memory_descs: MemoryDescs) -> bool: - """Check if remote descriptors are available. - - memory_descs should be C++ MemoryDescs type. - """ - return self._cpp_agent.check_remote_descs(name, memory_descs) + from ._agent_python import NixlTransferAgent, NixlTransferStatus - def notify_sync_message(self, name: str, sync_message: str): - """Send a sync message to a remote agent.""" - self._cpp_agent.notify_sync_message(name, sync_message) + logger.info("Using Python NIXL TransferAgent") - def get_notified_sync_messages(self): - """Get notified sync messages.""" - return self._cpp_agent.get_notified_sync_messages() - - @nvtx_range("BindingsNixlTransferAgent.submit_transfer_requests") - def submit_transfer_requests(self, request: TransferRequest) -> TransferStatus: - """Submit transfer requests and return status. - - request should be a C++ TransferRequest (from tensorrt_llm_transfer_agent_binding). - """ - cpp_status = self._cpp_agent.submit_transfer_requests(request) - return BindingsNixlTransferStatus(cpp_status) - - def _convert_reg_memory_descs(self, descs: RegMemoryDescs) -> "MemoryDescs": - """Convert Python RegMemoryDescs to C++ MemoryDescs. - - RegMemoryDescs.descs is List[Tuple[int, int, int, str]] = (ptr, size, device_id, name) - Extract first 3 elements for C++ batch constructor. 
- """ - mem_type = self._convert_memory_type(descs.type) - # Extract (ptr, size, device_id) from 4-tuple, discard name - tuples = [(d[0], d[1], d[2]) for d in descs.descs] - return MemoryDescs(mem_type, tuples) - - def _convert_memory_type(self, py_type: str) -> "MemoryType": - """Convert Python memory type string to C++ MemoryType.""" - type_map = { - "DRAM": MemoryType.DRAM, - "VRAM": MemoryType.VRAM, - "GPU": MemoryType.VRAM, - "BLK": MemoryType.BLK, - "OBJ": MemoryType.OBJ, - "FILE": MemoryType.FILE, - } - return type_map.get(py_type.upper(), MemoryType.VRAM) - - -# For backward compatibility, also keep the Python nixl-based implementation -NixlTransferAgent = None -NixlTransferStatus = None - -try: - from nixl import nixl_agent, nixl_agent_config, nixl_xfer_handle # noqa: E402 - - class NixlTransferStatus(TransferStatus): - """TransferStatus using Python nixl library.""" - - def __init__(self, agent: nixl_agent, handle: nixl_xfer_handle): - self.agent = agent - self.handle = handle - - def is_completed(self): - status = self.agent.check_xfer_state(self.handle) - return status == "DONE" - - def wait(self): - status = "PROC" - sleep_time = 0.0001 # 0.1ms - max_sleep_time = 0.01 # 10ms - while status == "PROC": - status = self.agent.check_xfer_state(self.handle) - if status == "ERROR": - return False # transfer failed - # sleep(0.1) - # sleep to release GIL - time.sleep(sleep_time) - sleep_time = min(sleep_time * 2, max_sleep_time) - return True - - class NixlTransferAgent(BaseTransferAgent): - """NixlTransferAgent using Python nixl library.""" - - def __init__(self, name: str, use_prog_thread: bool, num_workers: int = 1): - self.name = name - agent_config = nixl_agent_config( - enable_prog_thread=use_prog_thread, backends=["UCX"], num_threads=num_workers - ) - self.agent = nixl_agent(name, agent_config) - - def register_memory(self, descs: RegMemoryDescs): - reg_descs = self.agent.get_reg_descs(descs.descs, descs.type) - self.agent.register_memory(reg_descs) - - def deregister_memory(self, descs: RegMemoryDescs): - self.agent.deregister_memory(descs.descs, descs.type) - - def load_remote_agent(self, name: str, agent_desc: bytes): - self.agent.add_remote_agent(agent_desc) - - def get_local_agent_desc(self): - return self.agent.get_agent_metadata() - - def invalidate_remote_agent(self, name: str): - self.agent.remove_remote_agent(name) - - def check_remote_descs(self, name: str, memory_descs: list[int]) -> bool: - raise NotImplementedError - - def notify_sync_message(self, name: str, sync_message: str): - raise NotImplementedError - - @nvtx_range("NixlTransferAgent.submit_transfer_requests") - def submit_transfer_requests(self, request: TransferRequest) -> TransferStatus: - src_xfer_descs = self.agent.get_xfer_descs( - request.src_descs.descs, request.src_descs.type - ) - dst_xfer_descs = self.agent.get_xfer_descs( - request.dst_descs.descs, request.dst_descs.type - ) - handle = self.agent.initialize_xfer( - request.op, - src_xfer_descs, - dst_xfer_descs, - request.remote_name, - request.sync_message, - ) - status = self.agent.transfer(handle) - assert status != "ERROR" - return NixlTransferStatus(self.agent, handle) - -except ImportError: - # nixl library not available - pass +__all__ = ["NixlTransferStatus", "NixlTransferAgent"]