diff --git a/tensorrt_llm/_torch/disaggregation/__init__.py b/tensorrt_llm/_torch/disaggregation/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tensorrt_llm/_torch/disaggregation/base/__init__.py b/tensorrt_llm/_torch/disaggregation/base/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tensorrt_llm/_torch/disaggregation/base/agent.py b/tensorrt_llm/_torch/disaggregation/base/agent.py new file mode 100644 index 00000000000..bc9f10a2a72 --- /dev/null +++ b/tensorrt_llm/_torch/disaggregation/base/agent.py @@ -0,0 +1,78 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Enum, List, Tuple, Union + + +# Common Enumerations +class TransferOp(Enum): + READ = "READ" + WRITE = "WRITE" + + +class MemoryType(Enum): + DRAM = "DRAM" + VRAM = "VRAM" + BLK = "BLK" + OBJ = "OBJ" + FILE = "FILE" + + +# Common Data Structures +@dataclass +class MemoryDesc: + ptr: int + size: int + device_id: int + + +@dataclass +class MemoryDescs: + type: str + descs: List[Union[Tuple[int, int, int], MemoryDesc]] + + +@dataclass +class TransferRequest: + op: TransferOp + src_descs: MemoryDescs + dst_descs: MemoryDescs + remote_name: str + sync_message: str + + +class BaseTransferStatus(ABC): + """Abstract base class for transfer status.""" + + @abstractmethod + def is_completed(self) -> bool: ... + + @abstractmethod + def wait(self, timeout: float | None = None) -> None: ... + + +class BaseTransferAgent(ABC): + """Abstract base class for transfer agents.""" + + @abstractmethod + def register_memory(self, descs: MemoryDescs) -> None: ... + + @abstractmethod + def deregister_memory(self, descs: MemoryDescs) -> None: ... + + @abstractmethod + def load_remote_agent(self, name: str, agent_desc: str) -> None: ... + + @abstractmethod + def get_local_agent_desc(self) -> str: ... + + @abstractmethod + def invalidate_remote_agent(self, name: str) -> None: ... + + @abstractmethod + def submit_transfer_requests(self, request: TransferRequest) -> BaseTransferStatus: ... + + @abstractmethod + def notify_sync_message(self, name: str, sync_message: str) -> None: ... + + @abstractmethod + def check_remote_descs(self, name: str, memory_descs: List[int]) -> bool: ... diff --git a/tensorrt_llm/_torch/disaggregation/base/kv_transfer.py b/tensorrt_llm/_torch/disaggregation/base/kv_transfer.py new file mode 100644 index 00000000000..e45f4856f0e --- /dev/null +++ b/tensorrt_llm/_torch/disaggregation/base/kv_transfer.py @@ -0,0 +1,144 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from enum import Enum +from typing import List, Optional + +from tensorrt_llm import DisaggregatedParams + + +@dataclass +class KVSlice: + """Supports transmitting only part of the request cache, e.g, chunks or layers.""" + + start_token_idx: Optional[int] = None + end_token_idx: Optional[int] = None + start_layer: Optional[int] = None + end_layer: Optional[int] = None + blocks: List[int] = field(default_factory=list) + is_last_slice: bool = False + + +class SessionStatus(Enum): + """Status of a transfer session.""" + + INIT = "INIT" + READY = "READY" + TRANSFERRING = "TRANSFERRING" + TRANSFERRED = "TRANSFERRED" + AUX_TRANSFERRED = "AUX_TRANSFERRED" + COMPLETED = "COMPLETED" + CANCELED = "CANCELED" + ERROR = "ERROR" + + +TaskIdType = int + + +@dataclass +class SessionState: + """State of a transfer session.""" + + status: SessionStatus + finished_tasks: List[TaskIdType] + + +@dataclass +class SessionArgsBase: + """Base arguments for transfer sessions.""" + + params: DisaggregatedParams + + +class SenderBase(ABC): + """Base class for sending KV cache data.""" + + ... + + +class ReceiverBase(ABC): + """Base class for receiving KV cache data.""" + + ... + + +class TxSessionBase(ABC): + def __init__(self, sender: SenderBase, args: SessionArgsBase): + """ + Initializes the transmission session. + :param sender: The sender instance responsible for sending data. + :param args: The session arguments. + """ + self._base_args = args + + @property + @abstractmethod + def state(self) -> SessionState: + """ + Returns the current state of the session. + """ + ... + + @abstractmethod + def poll_task(self, id: TaskIdType) -> SessionStatus: + """ + Polls the status of a specific task by its ID. + :param id: The task ID to poll. + """ + ... + + @abstractmethod + def send(self, slice: KVSlice) -> TaskIdType: + """ + Sends a slice of KV cache data and returns the task ID. + :param slice: The KV slice to send. + """ + ... + + @property + @abstractmethod + def exception(self) -> Optional[Exception]: + """ + Returns any exception that occurred during the session. + """ + ... + + +class RxSessionBase(ABC): + def __init__(self, receiver: ReceiverBase, args: SessionArgsBase): + """ + Initializes the reception session. + :param receiver: The receiver instance responsible for receiving data. + """ + self._base_args = args + + @property + @abstractmethod + def state(self) -> SessionState: + """ + Returns the current state of the session. + """ + ... + + @abstractmethod + def poll_task(self, id: TaskIdType) -> SessionStatus: + """ + Polls the status of a specific task by its ID. + :param id: The task ID to poll. + """ + ... + + @abstractmethod + def receive(self, slice: KVSlice) -> TaskIdType: + """ + Receives a slice of KV cache data and returns the task ID. + :param slice: The KV slice to receive. + """ + ... + + @property + @abstractmethod + def exception(self) -> Optional[Exception]: + """Returns any exception that occurred during the session.""" + ... diff --git a/tensorrt_llm/_torch/disaggregation/native/__init__.py b/tensorrt_llm/_torch/disaggregation/native/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tensorrt_llm/_torch/disaggregation/native/messenger.py b/tensorrt_llm/_torch/disaggregation/native/messenger.py new file mode 100644 index 00000000000..ed72cc04446 --- /dev/null +++ b/tensorrt_llm/_torch/disaggregation/native/messenger.py @@ -0,0 +1,228 @@ +from abc import ABC, abstractmethod +from threading import Event, Lock, Thread +from typing import Callable, Optional + +import zmq + +from tensorrt_llm import logger +from tensorrt_llm._torch.disaggregation.native.utils import get_local_ip + + +class MessengerInterface(ABC): + """ + Abstract base class for messenger implementations. + """ + + @abstractmethod + def start(self): + """ + Start the messenger service. + """ + ... + + @abstractmethod + def send(self, messages: list[bytes], recipient: Optional[bytes] = None): + """ + Send messages to a recipient. + :param messages: List of byte messages to send. + :param recipient: Optional recipient identifier. + """ + ... + + @abstractmethod + def send_encoded(self, *messages, encoding: str = "ascii"): + """ + Send messages after encoding them. + :param messages: Messages to send. + :param encoding: Encoding format. + """ + ... + + @abstractmethod + def receive(self) -> list[bytes]: + """ + Receive messages. + :return: List of byte messages received. + """ + ... + + @abstractmethod + def start_listener(self, on_message: Callable[[list[bytes]], None]): + """ + Start a listener thread to handle incoming messages. + :param on_message: Callback function to process received messages. + """ + ... + + @abstractmethod + def stop(self): + """ + Stop the messenger service. + """ + ... + + @property + @abstractmethod + def endpoint(self) -> str: + """ + Get the endpoint of the messenger. + :return: Endpoint string. + """ + ... + + +def decode_message( + message: list[bytes], encoding: str = "ascii", err_mode: str = "strict" +) -> tuple: + if not isinstance(message, list) or not all(isinstance(m, bytes) for m in message): + raise ValueError("Input must be a list of bytes") + return tuple(m.decode(encoding, errors=err_mode) for m in message) + + +class ZMQMessenger(MessengerInterface): + SOCKET_MODES = { + "ROUTER": zmq.ROUTER, # Handles multiple connections and routes messages by address. + "DEALER": zmq.DEALER, # Load balances outgoing messages and receives replies fairly. + "REQ": zmq.REQ, # Sends requests and waits for replies (synchronous). + "REP": zmq.REP, # Receives requests and sends replies (synchronous). + } + + def __init__(self, mode: str, endpoint: Optional[str] = None): + self._context = zmq.Context() + self._mode = mode + self._socket = self._context.socket(self.SOCKET_MODES[mode]) + self._endpoint: Optional[str] = None + self._lock = Lock() + self._closed = False + self._stop_event = Event() + self._listener_thread: Optional[Thread] = None + self._initialize_control_sockets() + + if endpoint is None: + endpoint = f"tcp://{get_local_ip()}:*" + + if mode in ["ROUTER", "REP"]: + self._socket.bind(endpoint) + self._endpoint = self._socket.getsockopt_string(zmq.LAST_ENDPOINT) + elif mode in ["DEALER", "REQ"]: + self._socket.connect(endpoint) + self._endpoint = endpoint + + logger.info(f"Initialized ZMQMessenger(mode={mode}, endpoint={self._endpoint})") + + def _initialize_control_sockets(self): + self._control_socket = self._context.socket(zmq.PAIR) + self._internal_socket = self._context.socket(zmq.PAIR) + inproc_endpoint = "inproc://stop_listener" + self._control_socket.bind(inproc_endpoint) + self._internal_socket.connect(inproc_endpoint) + + def start(self): + pass + + def send(self, messages: list[bytes], recipient: Optional[bytes] = None): + if recipient: + self._socket.send_multipart([recipient] + messages) + else: + self._socket.send_multipart(messages) + + def send_encoded(self, *messages, encoding: str = "ascii"): + encoded_messages = [str(message).encode(encoding) for message in messages] + self.send(encoded_messages) + + def receive(self) -> list[bytes]: + return self._socket.recv_multipart() + + def start_listener( + self, + on_message: Callable[[list[bytes]], None], + on_error: Optional[Callable[[Exception], None]] = None, + ): + assert self._mode in ["ROUTER", "REP"], ( + "Listener can only be started in ROUTER or REP modes" + ) + if self._listener_thread and self._listener_thread.is_alive(): + raise RuntimeError("Listener already running") + + def listener(): + poller = zmq.Poller() + poller.register(self._socket, zmq.POLLIN) + poller.register(self._control_socket, zmq.POLLIN) + + while not self._stop_event.is_set(): + events = dict(poller.poll(timeout=100)) + try: + if self._control_socket in events: + self._stop_event.set() + elif self._socket in events: + messages = self.receive() + try: + persist = on_message(messages) + if persist is False: + self._stop_event.set() + except Exception as e: + logger.error(f"Error in on_message callback: {e}") + if on_error: + on_error(e) + else: + self._stop_event.set() + except zmq.ZMQError as e: + logger.error(f"ZMQ Error in listener: {e}") + if on_error: + on_error(e) + break + except Exception as e: + logger.error(f"Error in listener: {e}") + if on_error: + on_error(e) + break + + self._stop_event.set() + + self._listener_thread = Thread(target=listener, daemon=True) + self._listener_thread.start() + + def stop(self, timeout=5): + def _close_socket(socket: zmq.Socket): + try: + if not socket.closed: + socket.close() + except Exception as e: + logger.error(f"Error closing socket: {e}") + + with self._lock: + if self._closed: + return + self._closed = True + logger.debug("Stopping ZMQMessenger...") + + self._stop_event.set() + self._internal_socket.send(b"STOP") + if self._listener_thread: + self._listener_thread.join(timeout) + if self._listener_thread.is_alive(): + logger.warning("Listener thread did not terminate within timeout") + + _close_socket(self._socket) + _close_socket(self._internal_socket) + _close_socket(self._control_socket) + + try: + if self._context.closed: + self._context.term() + except Exception as e: + logger.error(f"Error terminating ZMQ context: {e}") + + @property + def endpoint(self) -> str: + assert self._endpoint is not None + return self._endpoint + + def __del__(self): + self.stop() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stop() diff --git a/tensorrt_llm/_torch/disaggregation/native/utils.py b/tensorrt_llm/_torch/disaggregation/native/utils.py new file mode 100644 index 00000000000..b4d0f5a7391 --- /dev/null +++ b/tensorrt_llm/_torch/disaggregation/native/utils.py @@ -0,0 +1,34 @@ +def get_local_ip() -> str: + try: + import socket + + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: + s.connect(("8.8.8.8", 80)) + ip = s.getsockname()[0] + if not ip.startswith("127."): + return ip + except OSError: + pass + + try: + import netifaces + + for iface in netifaces.interfaces(): + addrs = netifaces.ifaddresses(iface) + if netifaces.AF_INET in addrs: + for addr in addrs[netifaces.AF_INET]: + ip = addr.get("addr", "") + if not ip.startswith("127.") and not ip.startswith("169.254"): + return ip + except Exception: + pass + + try: + hostname = socket.gethostname() + ip = socket.gethostbyname(hostname) + if not ip.startswith("127."): + return ip + except OSError: + pass + + return "127.0.0.1" diff --git a/tensorrt_llm/_torch/disaggregation/nixl/__init__.py b/tensorrt_llm/_torch/disaggregation/nixl/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tensorrt_llm/_torch/disaggregation/nixl/_agent_cpp.py b/tensorrt_llm/_torch/disaggregation/nixl/_agent_cpp.py new file mode 100644 index 00000000000..214fe929838 --- /dev/null +++ b/tensorrt_llm/_torch/disaggregation/nixl/_agent_cpp.py @@ -0,0 +1,108 @@ +from tensorrt_llm._utils import nvtx_range +from tensorrt_llm.tensorrt_llm_transfer_agent_binding import ( + AgentDesc, + BaseAgentConfig, + MemoryDescs, + MemoryType, + TransferState, +) +from tensorrt_llm.tensorrt_llm_transfer_agent_binding import ( + NixlTransferAgent as CppNixlTransferAgent, +) +from tensorrt_llm.tensorrt_llm_transfer_agent_binding import ( + NixlTransferStatus as CppNixlTransferStatus, +) + +from ..base.agent import BaseTransferAgent, BaseTransferStatus, RegMemoryDescs, TransferRequest + + +class NixlTransferStatus(BaseTransferStatus): + def __init__(self, cpp_status: CppNixlTransferStatus): + self._cpp_status = cpp_status + + def is_completed(self) -> bool: + """Check if transfer is completed (releases GIL).""" + return self._cpp_status.is_completed() + + @nvtx_range("NixlTransferStatus.wait") + def wait(self, timeout: float = None) -> bool: + """Wait for transfer to complete (releases GIL).""" + return self._cpp_status.wait() == TransferState.SUCCESS + + +class NixlTransferAgent(BaseTransferAgent): + """NixlTransferAgent using C++ bindings with GIL release support. + + This implementation uses the standalone nixl_bindings C++ module which releases + the GIL during blocking operations like wait(). + + The nixl_bindings module is independent from the main trtllm bindings, + so trtllm can function normally even without NIXL. + """ + + def __init__(self, name: str, use_prog_thread: bool = True, num_workers: int = 1): + config = BaseAgentConfig( + name=name, + use_prog_thread=use_prog_thread, + multi_thread=False, + use_listen_thread=False, + num_workers=num_workers, + ) + self._cpp_agent = CppNixlTransferAgent(config) + self.name = name + + def register_memory(self, descs: RegMemoryDescs): + cpp_descs = self._convert_reg_memory_descs(descs) + self._cpp_agent.register_memory(cpp_descs) + + def deregister_memory(self, descs: RegMemoryDescs): + cpp_descs = self._convert_reg_memory_descs(descs) + self._cpp_agent.deregister_memory(cpp_descs) + + def load_remote_agent(self, name: str, agent_desc: bytes): + desc_str = agent_desc if isinstance(agent_desc, bytes) else agent_desc.encode() + cpp_desc = AgentDesc(desc_str) + self._cpp_agent.load_remote_agent(name, cpp_desc) + + def load_remote_agent_by_connection(self, name: str, connection_info: str): + self._cpp_agent.load_remote_agent_by_connection(name, connection_info) + + def get_local_agent_desc(self) -> bytes: + agent_desc = self._cpp_agent.get_local_agent_desc() + return agent_desc.backend_agent_desc + + def get_local_connection_info(self) -> str: + return self._cpp_agent.get_local_connection_info() + + def invalidate_remote_agent(self, name: str): + self._cpp_agent.invalidate_remote_agent(name) + + def check_remote_descs(self, name: str, memory_descs: MemoryDescs) -> bool: + return self._cpp_agent.check_remote_descs(name, memory_descs) + + def notify_sync_message(self, name: str, sync_message: str): + self._cpp_agent.notify_sync_message(name, sync_message) + + def get_notified_sync_messages(self): + return self._cpp_agent.get_notified_sync_messages() + + @nvtx_range("BindingsNixlTransferAgent.submit_transfer_requests") + def submit_transfer_requests(self, request: TransferRequest) -> BaseTransferStatus: + cpp_status = self._cpp_agent.submit_transfer_requests(request) + return NixlTransferStatus(cpp_status) + + def _convert_reg_memory_descs(self, descs: RegMemoryDescs) -> "MemoryDescs": + mem_type = self._convert_memory_type(descs.type) + tuples = [(d[0], d[1], d[2]) for d in descs.descs] # Extract (ptr, size, device_id) + return MemoryDescs(mem_type, tuples) + + def _convert_memory_type(self, py_type: str) -> "MemoryType": + type_map = { + "DRAM": MemoryType.DRAM, + "VRAM": MemoryType.VRAM, + "GPU": MemoryType.VRAM, + "BLK": MemoryType.BLK, + "OBJ": MemoryType.OBJ, + "FILE": MemoryType.FILE, + } + return type_map.get(py_type.upper(), MemoryType.VRAM) diff --git a/tensorrt_llm/_torch/disaggregation/nixl/_agent_python.py b/tensorrt_llm/_torch/disaggregation/nixl/_agent_python.py new file mode 100644 index 00000000000..2466e924170 --- /dev/null +++ b/tensorrt_llm/_torch/disaggregation/nixl/_agent_python.py @@ -0,0 +1,81 @@ +import time + +from nixl import nixl_agent, nixl_agent_config, nixl_xfer_handle + +from tensorrt_llm._utils import nvtx_range + +from ..base.agent import BaseTransferAgent, BaseTransferStatus, RegMemoryDescs, TransferRequest + + +class NixlTransferStatus(BaseTransferStatus): + """TransferStatus using the Python NIXL library.""" + + def __init__(self, agent: nixl_agent, handle: nixl_xfer_handle): + self.agent = agent + self.handle = handle + + def is_completed(self) -> bool: + status = self.agent.check_xfer_state(self.handle) + return status == "DONE" + + def wait(self) -> bool: + status = "PROC" + sleep_time = 0.0001 # 0.1ms + max_sleep_time = 0.01 # 10ms + while status == "PROC": + status = self.agent.check_xfer_state(self.handle) + if status == "ERROR": + return False # Transfer failed + time.sleep(sleep_time) # Sleep to release GIL + sleep_time = min(sleep_time * 2, max_sleep_time) + return status == "DONE" + + +class NixlTransferAgent(BaseTransferAgent): + """Python-based TransferAgent using the NIXL library.""" + + def __init__(self, name: str, use_prog_thread: bool, num_workers: int = 1): + self.name = name + agent_config = nixl_agent_config( + enable_prog_thread=use_prog_thread, + backends=["UCX"], + num_threads=num_workers, + ) + self.agent = nixl_agent(name, agent_config) + + def register_memory(self, descs: RegMemoryDescs): + reg_descs = self.agent.get_reg_descs(descs.descs, descs.type) + self.agent.register_memory(reg_descs) + + def deregister_memory(self, descs: RegMemoryDescs): + self.agent.deregister_memory(descs.descs, descs.type) + + def load_remote_agent(self, name: str, agent_desc: bytes): + self.agent.add_remote_agent(agent_desc) + + def get_local_agent_desc(self) -> bytes: + return self.agent.get_agent_metadata() + + def invalidate_remote_agent(self, name: str): + self.agent.remove_remote_agent(name) + + def check_remote_descs(self, name: str, memory_descs: list[int]) -> bool: + raise NotImplementedError("check_remote_descs is not implemented.") + + def notify_sync_message(self, name: str, sync_message: str): + raise NotImplementedError("notify_sync_message is not implemented.") + + @nvtx_range("NixlTransferAgent.submit_transfer_requests") + def submit_transfer_requests(self, request: TransferRequest) -> BaseTransferStatus: + src_xfer_descs = self.agent.get_xfer_descs(request.src_descs.descs, request.src_descs.type) + dst_xfer_descs = self.agent.get_xfer_descs(request.dst_descs.descs, request.dst_descs.type) + handle = self.agent.initialize_xfer( + request.op, + src_xfer_descs, + dst_xfer_descs, + request.remote_name, + request.sync_message, + ) + status = self.agent.transfer(handle) + assert status != "ERROR", "Transfer failed during initialization." + return NixlTransferStatus(self.agent, handle) diff --git a/tensorrt_llm/_torch/disaggregation/nixl/agent.py b/tensorrt_llm/_torch/disaggregation/nixl/agent.py new file mode 100644 index 00000000000..a4776096007 --- /dev/null +++ b/tensorrt_llm/_torch/disaggregation/nixl/agent.py @@ -0,0 +1,25 @@ +from tensorrt_llm import logger + +"""NIXL Transfer Agent implementations. + +This module provides two implementations: +1. BindingsNixlTransferAgent - Uses the standalone nixl_bindings C++ module with GIL release support +2. NixlTransferAgent - Uses the Python nixl library directly (fallback) + +The standalone nixl_bindings module is separate from the main trtllm bindings, +so trtllm can still function normally even without NIXL dependencies. +""" + +try: + # Try to import the standalone tensorrt_llm_transfer_agent_binding module + # Located at tensorrt_llm/ (same level as bindings.so) + from ._agent_cpp import NixlTransferAgent, NixlTransferStatus + + logger.info("Using C++ NIXL TransferAgent") + +except ImportError: + from ._agent_python import NixlTransferAgent, NixlTransferStatus + + logger.info("Using Python NIXL TransferAgent") + +__all__ = ["NixlTransferStatus", "NixlTransferAgent"] diff --git a/tests/unittest/disaggregated/test_messenger.py b/tests/unittest/disaggregated/test_messenger.py new file mode 100644 index 00000000000..94e052d531e --- /dev/null +++ b/tests/unittest/disaggregated/test_messenger.py @@ -0,0 +1,127 @@ +import socket +import time +import unittest + +import pytest +from parameterized import parameterized + +from tensorrt_llm._torch.disaggregation.native.messenger import ZMQMessenger, decode_message +from tensorrt_llm._torch.disaggregation.native.utils import get_local_ip + +TEST_CASES = [ + { + "name": "valid_message", + "message": [b"hello", b"world"], + "encoding": "utf-8", + "err_mode": "strict", + "expected": ("hello", "world"), + "raises": None, + }, + { + "name": "invalid_input", + "message": ["hello", b"world"], + "encoding": "utf-8", + "err_mode": "strict", + "expected": None, + "raises": ValueError, + }, + { + "name": "decoding_error", + "message": [b"\xff"], + "encoding": "utf-8", + "err_mode": "strict", + "expected": None, + "raises": UnicodeDecodeError, + }, + { + "name": "decoding_with_ignore", + "message": [b"\xff"], + "encoding": "utf-8", + "err_mode": "ignore", + "expected": ("",), + "raises": None, + }, +] + + +class TestDecodeMessage(unittest.TestCase): + @parameterized.expand([(case["name"], case) for case in TEST_CASES]) + def test_decode_message(self, name, case): + message = case["message"] + encoding = case["encoding"] + err_mode = case["err_mode"] + expected = case["expected"] + raises = case["raises"] + + if raises: + with self.assertRaises(raises): + decode_message(message, encoding=encoding, err_mode=err_mode) + else: + decoded = decode_message(message, encoding=encoding, err_mode=err_mode) + self.assertEqual(decoded, expected) + + +@pytest.fixture +def dynamic_endpoint(): + """Fixture to dynamically generate an available endpoint with a free port.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) # Bind to an available port provided by the OS + port = s.getsockname()[1] + return f"tcp://{get_local_ip()}:{port}" + + +@pytest.fixture +def create_messenger_pair(dynamic_endpoint): + def _create_messenger_pair(mode1, mode2): + messenger1 = ZMQMessenger( + mode1, endpoint=dynamic_endpoint if mode1 in ["ROUTER", "REP"] else None + ) + messenger2 = ZMQMessenger( + mode2, endpoint=dynamic_endpoint if mode2 in ["DEALER", "REQ"] else None + ) + return messenger1, messenger2 + + yield _create_messenger_pair + + +def test_router_dealer(create_messenger_pair): + """Test ROUTER and DEALER communication.""" + router, dealer = create_messenger_pair("ROUTER", "DEALER") + + received_messages = [] + + def on_message(messages): + received_messages.extend(messages) + + router.start_listener(on_message) + + dealer.send([b"Hello, ROUTER!"]) + + time.sleep(0.1) + + assert len(received_messages) > 0 + assert b"Hello, ROUTER!" in received_messages + + router.stop() + dealer.stop() + + +def test_req_rep(create_messenger_pair): + """Test REQ and REP communication.""" + rep, req = create_messenger_pair("REP", "REQ") + + def on_message(messages): + rep.send(messages) + + rep.start_listener(on_message) + + req.send([b"Hello, REP!"]) + response = req.receive() + assert response == [b"Hello, REP!"] + + req.stop() + rep.stop() + + +if __name__ == "__main__": + unittest.main()