from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import List, Tuple, Union

from tensorrt_llm.logger import logger
46
# Prefer the C++ binding types when the extension module can be imported.
# The try body holds only the import, so an ImportError can only come from
# the binding itself; the availability flag is set in the except/else arms.
try:
    from tensorrt_llm.tensorrt_llm_transfer_agent_binding import (
        BaseTransferAgent,
        MemoryDesc,
        MemoryDescs,
        MemoryType,
        TransferOp,
        TransferRequest,
        TransferStatus,
    )
except ImportError:
    # Binding missing: record that and tell the user the Python fallbacks
    # defined later in this module will be used instead.
    _CPP_BINDING_AVAILABLE = False
    logger.warning(
        "C++ transfer agent bindings not available. "
        "Falling back to Python implementations which may have lower performance."
    )
else:
    _CPP_BINDING_AVAILABLE = True
2126
2227
2328def is_cpp_binding_available () -> bool :
@@ -28,11 +33,11 @@ def is_cpp_binding_available() -> bool:
2833# Fallback Python implementations when C++ bindings not available
2934if not _CPP_BINDING_AVAILABLE :
3035
class TransferOp(str, Enum):
    """Transfer operation identifiers for the pure-Python fallback.

    The members are real ``Enum`` members, but the ``str`` mix-in keeps them
    equal to the plain string constants the fallback exposed before it
    became an enum (``TransferOp.READ == "READ"``), so callers that compare
    against or pass along the raw strings keep working.
    """

    READ = "READ"
    WRITE = "WRITE"

    def __str__(self) -> str:
        # Render as the bare value ("READ"), not "TransferOp.READ", matching
        # what str() returned when these were plain string attributes.
        return self.value
3439
35- class MemoryType :
40+ class MemoryType ( Enum ) :
3641 DRAM = "DRAM"
3742 VRAM = "VRAM"
3843 BLK = "BLK"
@@ -67,28 +72,52 @@ def wait(self, timeout: float | None = None) -> None: ...
6772
class BaseTransferAgent(ABC):
    """Abstract transfer-agent interface used by the Python fallback path.

    Every method is abstract and carries no implementation here; a concrete
    agent must override all of them before it can be instantiated.
    """

    @abstractmethod
    def register_memory(self, descs: MemoryDescs) -> None:
        """Register the given set of memory descriptors with this agent."""

    @abstractmethod
    def deregister_memory(self, descs: MemoryDescs) -> None:
        """Remove the given set of memory descriptors from this agent."""

    @abstractmethod
    def load_remote_agent(self, name: str, agent_desc: str) -> None:
        """Load the information describing a remote agent.

        Args:
            name: Identifier of the remote agent.
            agent_desc: Serialized description of that agent.
        """

    @abstractmethod
    def get_local_agent_desc(self) -> str:
        """Return this agent's own serialized description."""

    @abstractmethod
    def invalidate_remote_agent(self, name: str) -> None:
        """Drop any cached information held about the remote agent ``name``."""

    @abstractmethod
    def submit_transfer_requests(self, request: TransferRequest) -> TransferStatus:
        """Submit the transfer described by ``request`` and return its status handle."""

    @abstractmethod
    def notify_sync_message(self, name: str, sync_message: str) -> None:
        """Send ``sync_message`` as a synchronization message to the remote agent ``name``."""

    @abstractmethod
    def check_remote_descs(self, name: str, memory_descs: List[int]) -> bool:
        """Verify memory descriptors of the remote agent ``name``.

        NOTE(review): what a True/False result means is not defined in this
        interface; confirm against a concrete implementation.
        """
92121
93122
94123# RegMemoryDescs is Python-only (used for registration with name field)
0 commit comments