[TRTLLM-9527][feat] Python transceiver components (step 2) #10494
base: main
Conversation
Signed-off-by: Shixiaowei02 <[email protected]>
📝 Walkthrough
Introduces a disaggregated KV cache transfer subsystem comprising base abstractions for transfer agents, session management for bidirectional KV slice transfer, ZeroMQ-based messaging infrastructure, and transfer agent implementations with optional C++ bindings and Python fallbacks.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~30 minutes
🚥 Pre-merge checks: ❌ 2 failed (2 warnings) | ✅ 1 passed
✏️ Tip: You can configure your own custom pre-merge checks in the settings.
Actionable comments posted: 12
🤖 Fix all issues with AI agents
In @tensorrt_llm/_torch/disaggregation/base/agent.py:
- Around line 1-4: Add the required NVIDIA copyright header block at the very
top of this source file (before any imports) using the project’s standard header
template (including SPDX/license tag, year(s), and "NVIDIA CORPORATION &
AFFILIATES" attribution); ensure the header matches other TensorRT-LLM files and
update the year range if needed so the file (agent.py) conforms to the
repository’s licensing/copyright guidelines.
- Line 66: The wait method signature uses Python 3.10 union syntax ("def
wait(self, timeout: float | None = None) -> None") which is incompatible with
the project's Python 3.8+ requirement; change the annotation to use
typing.Optional (e.g., "timeout: Optional[float] = None") and add an import for
Optional from typing if not already present, leaving the method name wait and
behavior unchanged.
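As a sketch of that suggested change (the class and method names are taken from the `agent.py` interface described above), the 3.8-compatible annotation looks like:

```python
from abc import ABC, abstractmethod
from typing import Optional


class TransferStatus(ABC):
    """Sketch of the status interface; only the annotation style matters here."""

    @abstractmethod
    def is_completed(self) -> bool: ...

    @abstractmethod
    def wait(self, timeout: Optional[float] = None) -> None:
        """Block until the transfer completes; `Optional[float]` parses on Python 3.8."""
```

`Optional[float]` is exactly equivalent to `float | None` at runtime, so no call sites change.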
In @tensorrt_llm/_torch/disaggregation/base/kv_transfer.py:
- Around line 89-94: The docstring for the abstract method receive is misplaced
after the method body; move the triple-quoted docstring so it immediately
follows the receive(self, slice: KVSlice) -> TaskIdType: signature (so it
becomes the method's docstring), ensure it uses proper triple quotes and
references behavior (async receive slice from the peer, returns TaskIdType, task
state polled via poll_task()), and keep the @abstractmethod decorator and
signature intact.
- Around line 1-9: Add the standard NVIDIA copyright header to the top of this
source file (before any imports/definitions) so it matches other project files;
insert the project's canonical header block (including copyright year, NVIDIA
ownership and any SPDX/license tag your repo uses) at the top of kv_transfer.py
(where DisaggregatedParams is imported) to satisfy the coding guidelines.
- Around line 66-71: The docstring currently sits after the abstract method
signature for send (def send(self, slice: KVSlice) -> TaskIdType) and is
therefore a no-op; move the triple-quoted string into the send method body as
its proper docstring (directly under the signature) so the abstractmethod has an
explanatory docstring describing the async send behavior and that the returned
TaskIdType can be polled via poll_task(); remove the standalone string
expression left after the method signature and ensure the method still respects
@abstractmethod (body can be the docstring alone or followed by raise
NotImplementedError).
In @tensorrt_llm/_torch/disaggregation/native/messenger.py:
- Around line 1-9: Add the standard NVIDIA copyright header to the top of the
tensorrt_llm._torch.disaggregation.native.messenger module (before any imports
or code), matching the project's required header format and year range; ensure
the header appears as the first lines of the file so it precedes the existing
imports and module code.
- Around line 96-104: The ZMQMessenger constructor can leave self._endpoint
uninitialized when endpoint is falsy, causing AttributeError later (e.g., in the
endpoint property); initialize self._endpoint to a default (None or empty
string) before the conditional that checks endpoint, then preserve the existing
logic that sets self._endpoint inside the if/elif branches for "ROUTER"/"REP"
and "DEALER"/"REQ" so the attribute is always defined whether or not an endpoint
was provided.
- Line 81: The endpoint default is evaluated at import time because the f-string
calls get_local_ip() in the function signature; change the __init__ signature to
use endpoint: Optional[str] = None and inside __init__ resolve the default by
computing f"tcp://{get_local_ip()}:*" when endpoint is None (or
validate/overwrite endpoint accordingly). Update any docstring or comments for
the Messenger class/method to reflect that endpoint is now resolved at
instantiation time and ensure references to the parameter in methods use the
instance attribute (e.g., self.endpoint) set from the resolved value.
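A minimal sketch of that change (the `Messenger` and `get_local_ip` shown here are simplified stand-ins for the real implementations):

```python
import socket
from typing import Optional


def get_local_ip() -> str:
    # Stand-in for the project's helper; falls back to loopback if lookup fails.
    try:
        return socket.gethostbyname(socket.gethostname())
    except OSError:
        return "127.0.0.1"


class Messenger:
    def __init__(self, endpoint: Optional[str] = None):
        # Resolving the default here means get_local_ip() runs at instantiation
        # time, not once at import time as an f-string default in the signature would.
        self._endpoint = endpoint if endpoint is not None else f"tcp://{get_local_ip()}:*"

    @property
    def endpoint(self) -> str:
        return self._endpoint
```

The `None` sentinel is the standard pattern for any default whose value should be computed per-instance.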
In @tensorrt_llm/_torch/disaggregation/native/utils.py:
- Around line 1-34: Add the required NVIDIA copyright header to the top of this
source file (the one defining get_local_ip) including the appropriate year of
the latest meaningful modification and the standard NVIDIA header text used
across the project; place it above all imports and code, ensure formatting
matches other TensorRT-LLM files, and update the year if this file was modified
after the existing header template's year.
In @tensorrt_llm/_torch/disaggregation/nixl/agent.py:
- Around line 1-13: Add the required NVIDIA copyright header at the top of the
file before the existing module docstring; update
tensorrt_llm/_torch/disaggregation/nixl/agent.py by inserting the standard
multi-line copyright header as the first lines (above the triple-quoted module
docstring that currently describes NIXL Transfer Agent implementations and
before the import of time and nvtx_range).
- Around line 246-249: The code uses assert to check transfer result in the
block that calls self.agent.transfer(handle); replace the assert with explicit
runtime error handling: inspect the returned status from
self.agent.transfer(handle) and if it equals "ERROR" raise a descriptive
exception (e.g., RuntimeError or a custom TransferError) including context such
as the handle and/or status; otherwise return NixlTransferStatus(self.agent,
handle). Ensure this change is applied where status =
self.agent.transfer(handle) is used (in the method that returns
NixlTransferStatus).
- Around line 217-218: The load_remote_agent method currently accepts a name
parameter but never uses it; update load_remote_agent to forward name to the
nixl library call (call add_remote_agent with both name and agent_desc) if the
Python nixl binding supports a name argument, mirroring
BindingsNixlTransferAgent, or if the binding does not accept a name then add a
short comment/docstring in load_remote_agent explaining why name is
intentionally ignored and that the C++ binding handles naming instead.
🧹 Nitpick comments (11)
tensorrt_llm/_torch/disaggregation/native/utils.py (2)
2-11: `socket` import is scoped inside a try block but used outside.
The `socket` module is imported at line 3 inside the first try block, but it's used again at lines 27-28 in the third fallback. If the first try block raises an exception before the import succeeds (unlikely but possible), or if it's skipped due to code refactoring, a `NameError` will occur.
Move import to module level:
+import socket
+
 def get_local_ip() -> str:
     try:
-        import socket
-        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
23-24: Catching blind `Exception` silently discards useful diagnostics.
Consider catching a more specific exception (e.g., `ImportError` for missing `netifaces`, `KeyError`, `OSError`) or at least logging the exception for debugging purposes.
Proposed fix:
-    except Exception:
-        pass
+    except (ImportError, KeyError, OSError):
+        pass
tensorrt_llm/_torch/disaggregation/native/messenger.py (3)
121-125: Prefer unpacking over list concatenation.
Using `[recipient, *messages]` is more idiomatic and slightly more efficient than `[recipient] + messages`.
Proposed fix:
 def send(self, messages: list[bytes], recipient: Optional[bytes] = None):
     if recipient:
-        self._socket.send_multipart([recipient] + messages)
+        self._socket.send_multipart([recipient, *messages])
     else:
         self._socket.send_multipart(messages)
153-158: Broad exception catch may mask programming errors.
Catching all `Exception` types could hide bugs. Consider being more specific or at least including the exception type in the log message.
Include exception type in error log:
 except zmq.ZMQError as e:
     logger.error(f"ZMQ Error in listener: {e}")
     continue
 except Exception as e:
-    logger.error(f"Error in listener: {e}")
+    logger.error(f"Unexpected {type(e).__name__} in listener: {e}")
     continue
182-183: `__del__` calling `stop()` may cause issues during interpreter shutdown.
During interpreter shutdown, the `zmq` module or `logger` may already be garbage collected when `__del__` is called, potentially causing errors. Consider wrapping in a try-except or checking if resources are valid.
Proposed fix:
 def __del__(self):
-    self.stop()
+    try:
+        self.stop()
+    except Exception:
+        pass  # Suppress errors during interpreter shutdown
tensorrt_llm/_torch/disaggregation/base/kv_transfer.py (2)
49-52: Empty abstract base classes serve no purpose.
`SenderBase` and `ReceiverBase` are ABCs with no abstract methods or properties. Consider either adding the abstract methods they should define, or using regular base classes if they're just type markers.
55-57: Unused constructor arguments `sender` and `receiver`.
The `sender` and `receiver` arguments are accepted but never used. If they're intended for future use, consider storing them as instance attributes or adding a comment explaining the intent.
Store the argument for future use:
 class TxSessionBase(ABC):
     def __init__(self, sender: SenderBase, args: SessionArgsBase):
+        self._sender = sender
         self._base_args = args
 ...
 class RxSessionBase(ABC):
     def __init__(self, receiver: ReceiverBase, args: SessionArgsBase):
+        self._receiver = receiver
         self._base_args = args
Also applies to: 78-80
tensorrt_llm/_torch/disaggregation/base/agent.py (1)
31-40: Consider using `Enum` for the `TransferOp` and `MemoryType` fallbacks.
Using class attributes with string values works but loses type safety. Consider using `Enum` to match the likely C++ binding behavior and provide better IDE support.
Example using Enum:
+from enum import Enum
+
-class TransferOp:
-    READ = "READ"
-    WRITE = "WRITE"
+class TransferOp(Enum):
+    READ = "READ"
+    WRITE = "WRITE"
-class MemoryType:
-    DRAM = "DRAM"
-    VRAM = "VRAM"
-    BLK = "BLK"
-    OBJ = "OBJ"
-    FILE = "FILE"
+class MemoryType(Enum):
+    DRAM = "DRAM"
+    VRAM = "VRAM"
+    BLK = "BLK"
+    OBJ = "OBJ"
+    FILE = "FILE"
tensorrt_llm/_torch/disaggregation/nixl/agent.py (3)
96-101: Type annotation says `bytes` but the code handles the `str` case.
The type hint declares `agent_desc: bytes`, but line 99 handles the case where it might be a string. Either update the type hint to `Union[bytes, str]` or remove the isinstance check if only bytes are expected.
Option 1: Update type hint to reflect actual usage:
-def load_remote_agent(self, name: str, agent_desc: bytes):
+def load_remote_agent(self, name: str, agent_desc: Union[bytes, str]):
     """Load a remote agent by its descriptor (bytes)."""
     # AgentDesc expects std::string which can hold binary data
     desc_str = agent_desc if isinstance(agent_desc, bytes) else agent_desc.encode()
Add `Union` to the imports at the top: `from typing import Union`
22-22: Remove unused `noqa` directives.
Static analysis indicates the `# noqa: E402` comments are unnecessary, as the E402 rule is not enabled.
Proposed fix:
-import tensorrt_llm.tensorrt_llm_transfer_agent_binding as _agent_binding  # noqa: E402
+import tensorrt_llm.tensorrt_llm_transfer_agent_binding as _agent_binding
-from nixl import nixl_agent, nixl_agent_config, nixl_xfer_handle  # noqa: E402
+from nixl import nixl_agent, nixl_agent_config, nixl_xfer_handle
Also applies to: 173-173
186-198: Exponential backoff with sleep is reasonable, but consider a maximum iteration count.
The wait loop uses exponential backoff, which is good, but there is no maximum iteration limit. In pathological cases, this could loop indefinitely if the status never changes from "PROC".
Add maximum wait time or iteration limit:
 def wait(self):
     status = "PROC"
     sleep_time = 0.0001  # 0.1 ms
     max_sleep_time = 0.01  # 10 ms
+    max_total_wait = 300  # 5 minutes
+    total_waited = 0
     while status == "PROC":
         status = self.agent.check_xfer_state(self.handle)
         if status == "ERROR":
             return False  # transfer failed
-        # sleep(0.1)  # sleep to release GIL
         time.sleep(sleep_time)
+        total_waited += sleep_time
+        if total_waited >= max_total_wait:
+            return False  # timeout
         sleep_time = min(sleep_time * 2, max_sleep_time)
     return True
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (9)
- tensorrt_llm/_torch/disaggregation/__init__.py
- tensorrt_llm/_torch/disaggregation/base/__init__.py
- tensorrt_llm/_torch/disaggregation/base/agent.py
- tensorrt_llm/_torch/disaggregation/base/kv_transfer.py
- tensorrt_llm/_torch/disaggregation/native/__init__.py
- tensorrt_llm/_torch/disaggregation/native/messenger.py
- tensorrt_llm/_torch/disaggregation/native/utils.py
- tensorrt_llm/_torch/disaggregation/nixl/__init__.py
- tensorrt_llm/_torch/disaggregation/nixl/agent.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+
Indent Python code with 4 spaces. Do not use tabs
Always maintain the namespace when importing Python modules, even if only one class or function from a module is used
Python filenames should use snake_case (e.g., `some_file.py`)
Python classes should use PascalCase (e.g., `class SomeClass`)
Python functions and methods should use snake_case (e.g., `def my_awesome_function():`)
Python local variables should use snake_case, with prefix `k` for variable names that start with a number (e.g., `k_99th_percentile`)
Python global variables should use upper snake_case with prefix `G` (e.g., `G_MY_GLOBAL`)
Python constants should use upper snake_case (e.g., `MY_CONSTANT`)
Avoid shadowing variables declared in an outer scope in Python
Initialize all externally visible members of a Python class in the constructor
For Python interfaces that may be used outside a file, prefer docstrings over comments
Use comments in Python for code within a function, or interfaces that are local to a file
Use Google-style docstrings for Python classes and functions, which can be parsed by Sphinx
Python attributes and variables can be documented inline with the format `"""<type>: Description"""`
Avoid using reflection in Python when functionality can be easily achieved without reflection
When using try-except blocks in Python, limit the except clause to the smallest set of errors possible
When using try-except blocks in Python to handle multiple possible variable types (duck-typing), keep the body of the try as small as possible and use the else block for the main logic
Files:
- tensorrt_llm/_torch/disaggregation/native/utils.py
- tensorrt_llm/_torch/disaggregation/base/kv_transfer.py
- tensorrt_llm/_torch/disaggregation/nixl/agent.py
- tensorrt_llm/_torch/disaggregation/native/messenger.py
- tensorrt_llm/_torch/disaggregation/base/agent.py
**/*.{cpp,cc,cxx,h,hpp,hxx,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
All TensorRT-LLM source files (.cpp, .h, .cu, .py, and other source files) should contain an NVIDIA copyright header with the year of latest meaningful modification
Files:
- tensorrt_llm/_torch/disaggregation/native/utils.py
- tensorrt_llm/_torch/disaggregation/base/kv_transfer.py
- tensorrt_llm/_torch/disaggregation/nixl/agent.py
- tensorrt_llm/_torch/disaggregation/native/messenger.py
- tensorrt_llm/_torch/disaggregation/base/agent.py
🧬 Code graph analysis (4)
tensorrt_llm/_torch/disaggregation/base/kv_transfer.py (2)
tensorrt_llm/executor/result.py (1)
request_id (727-728)
tensorrt_llm/_torch/disaggregation/native/messenger.py (4)
send (24-30), send (121-125), receive (42-47), receive (131-132)
tensorrt_llm/_torch/disaggregation/nixl/agent.py (3)
tensorrt_llm/_utils.py (1)
nvtx_range (891-910)
tensorrt_llm/_torch/disaggregation/base/agent.py (15)
BaseTransferAgent (68-91), RegMemoryDescs (96-98), TransferRequest (54-59), TransferStatus (61-66), MemoryType (35-40), MemoryDescs (49-51), wait (66-66), register_memory (70-70), deregister_memory (73-73), load_remote_agent (76-76), get_local_agent_desc (79-79), invalidate_remote_agent (82-82), check_remote_descs (91-91), notify_sync_message (88-88), submit_transfer_requests (85-85)
cpp/include/tensorrt_llm/executor/transferAgent.h (2)
AgentDesc (199-428), TransferState (277-426)
tensorrt_llm/_torch/disaggregation/native/messenger.py (1)
tensorrt_llm/_torch/disaggregation/native/utils.py (1)
get_local_ip (1-34)
tensorrt_llm/_torch/disaggregation/base/agent.py (1)
tensorrt_llm/_torch/disaggregation/nixl/agent.py (8)
is_completed (50-52), is_completed (182-184), wait (55-57), wait (186-198), register_memory (86-89), register_memory (210-212), check_remote_descs (120-125), check_remote_descs (226-227)
🪛 Ruff (0.14.10)
tensorrt_llm/_torch/disaggregation/native/utils.py
23-24: try-except-pass detected, consider logging the exception
(S110)
23-23: Do not catch blind exception: Exception
(BLE001)
tensorrt_llm/_torch/disaggregation/base/kv_transfer.py
49-49: SenderBase is an abstract base class, but it has no abstract methods or properties
(B024)
52-52: ReceiverBase is an abstract base class, but it has no abstract methods or properties
(B024)
56-56: Unused method argument: sender
(ARG002)
79-79: Unused method argument: receiver
(ARG002)
tensorrt_llm/_torch/disaggregation/nixl/agent.py
22-22: Unused noqa directive (non-enabled: E402)
Remove unused noqa directive
(RUF100)
72-75: Avoid specifying long messages outside the exception class
(TRY003)
173-173: Unused noqa directive (non-enabled: E402)
Remove unused noqa directive
(RUF100)
217-217: Unused method argument: name
(ARG002)
tensorrt_llm/_torch/disaggregation/native/messenger.py
76-76: Avoid specifying long messages outside the exception class
(TRY003)
116-116: Avoid specifying long messages outside the exception class
(TRY003)
123-123: Consider [recipient, *messages] instead of concatenation
Replace with [recipient, *messages]
(RUF005)
136-136: Avoid specifying long messages outside the exception class
(TRY003)
156-156: Do not catch blind exception: Exception
(BLE001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
    @abstractmethod
    def send(self, slice: KVSlice) -> TaskIdType: ...

    """
    Async send slice to the peer. return the task id. Task state can be polled by poll_task().
    """
Misplaced docstring appears after method definition.
The docstring at lines 69-71 appears after the send method signature instead of inside it, making it a standalone string expression with no effect. Move it inside the method body.
Proposed fix
@abstractmethod
- def send(self, slice: KVSlice) -> TaskIdType: ...
-
- """
- Async send slice to the peer. return the task id. Task state can be polled by poll_task().
- """
+ def send(self, slice: KVSlice) -> TaskIdType:
+ """Async send slice to the peer. Return the task id. Task state can be polled by poll_task()."""
+        ...
🤖 Prompt for AI Agents
In @tensorrt_llm/_torch/disaggregation/base/kv_transfer.py around lines 66 - 71,
The docstring currently sits after the abstract method signature for send (def
send(self, slice: KVSlice) -> TaskIdType) and is therefore a no-op; move the
triple-quoted string into the send method body as its proper docstring (directly
under the signature) so the abstractmethod has an explanatory docstring
describing the async send behavior and that the returned TaskIdType can be
polled via poll_task(); remove the standalone string expression left after the
method signature and ensure the method still respects @abstractmethod (body can
be the docstring alone or followed by raise NotImplementedError).
    @abstractmethod
    def receive(self, slice: KVSlice) -> TaskIdType: ...

    """
    Async receive slice from the peer. return the task id. Task state can be polled by poll_task().
    """
Misplaced docstring appears after method definition.
Same issue as above - the docstring at lines 92-94 is not attached to the receive method.
Proposed fix
@abstractmethod
- def receive(self, slice: KVSlice) -> TaskIdType: ...
-
- """
- Async receive slice from the peer. return the task id. Task state can be polled by poll_task().
- """
+ def receive(self, slice: KVSlice) -> TaskIdType:
+ """Async receive slice from the peer. Return the task id. Task state can be polled by poll_task()."""
+        ...
🤖 Prompt for AI Agents
In @tensorrt_llm/_torch/disaggregation/base/kv_transfer.py around lines 89 - 94,
The docstring for the abstract method receive is misplaced after the method
body; move the triple-quoted docstring so it immediately follows the
receive(self, slice: KVSlice) -> TaskIdType: signature (so it becomes the
method's docstring), ensure it uses proper triple quotes and references behavior
(async receive slice from the peer, returns TaskIdType, task state polled via
poll_task()), and keep the @abstractmethod decorator and signature intact.
        if endpoint:
            if mode in ["ROUTER", "REP"]:
                self._socket.bind(endpoint)
                self._endpoint = self._socket.getsockopt_string(zmq.LAST_ENDPOINT)
            elif mode in ["DEALER", "REQ"]:
                self._socket.connect(endpoint)
                self._endpoint = endpoint

        logger.debug(f"Initializing ZMQMessenger, mode={mode}, endpoint={self._endpoint}")
self._endpoint may be uninitialized if endpoint is falsy.
If endpoint is None or empty, the if endpoint: block is skipped and self._endpoint is never assigned, leading to an AttributeError at line 104 and in the endpoint property.
Initialize `_endpoint` before the conditional
self._control_socket.bind(inproc_endpoint)
self._internal_socket.connect(inproc_endpoint)
+ self._endpoint = None
if endpoint:
if mode in ["ROUTER", "REP"]:
self._socket.bind(endpoint)
self._endpoint = self._socket.getsockopt_string(zmq.LAST_ENDPOINT)
elif mode in ["DEALER", "REQ"]:
self._socket.connect(endpoint)
self._endpoint = endpoint
+ else:
+ raise ValueError("endpoint must be provided")
- logger.debug(f"Initializing ZMQMessenger, mode={mode}, endpoint={self._endpoint}")
+ if self._endpoint:
+     logger.debug(f"Initializing ZMQMessenger, mode={mode}, endpoint={self._endpoint}")
🤖 Prompt for AI Agents
In @tensorrt_llm/_torch/disaggregation/native/messenger.py around lines 96 -
104, The ZMQMessenger constructor can leave self._endpoint uninitialized when
endpoint is falsy, causing AttributeError later (e.g., in the endpoint
property); initialize self._endpoint to a default (None or empty string) before
the conditional that checks endpoint, then preserve the existing logic that sets
self._endpoint inside the if/elif branches for "ROUTER"/"REP" and "DEALER"/"REQ"
so the attribute is always defined whether or not an endpoint was provided.
    def load_remote_agent(self, name: str, agent_desc: bytes):
        self.agent.add_remote_agent(agent_desc)
🧩 Analysis chain
🏁 Script executed:
# First, let's look at the file and the method in context
cat -n tensorrt_llm/_torch/disaggregation/nixl/agent.py | sed -n '210,230p'
Repository: NVIDIA/TensorRT-LLM
Length of output: 1063
🏁 Script executed:
# Search for the method definition and any usages of the name parameter
rg -A 10 'def load_remote_agent' tensorrt_llm/_torch/disaggregation/nixl/agent.py
Repository: NVIDIA/TensorRT-LLM
Length of output: 1408
🏁 Script executed:
# Check if this class inherits from a base class and if there's an interface definition
rg -B 20 'def load_remote_agent' tensorrt_llm/_torch/disaggregation/nixl/agent.py | head -40
Repository: NVIDIA/TensorRT-LLM
Length of output: 1677
🏁 Script executed:
# Search for any base class definitions or abstract interfaces that might define this method
fd -t py . tensorrt_llm/_torch/disaggregation/ | xargs rg 'class.*Agent' -A 5
Repository: NVIDIA/TensorRT-LLM
Length of output: 19669
🏁 Script executed:
# Search for add_remote_agent usage and definition
rg -B 5 -A 5 'add_remote_agent' tensorrt_llm/_torch/disaggregation/
Repository: NVIDIA/TensorRT-LLM
Length of output: 995
🏁 Script executed:
# Look at the full BindingsNixlTransferAgent class definition
cat -n tensorrt_llm/_torch/disaggregation/nixl/agent.py | sed -n '150,230p'
Repository: NVIDIA/TensorRT-LLM
Length of output: 3704
🏁 Script executed:
# Search the C++ bindings to see what methods are exposed
rg 'add_remote_agent|load_remote_agent' cpp/tensorrt_llm/executor/cache_transmission/
Repository: NVIDIA/TensorRT-LLM
Length of output: 1526
🏁 Script executed:
# Check the abstract base class definition to see if load_remote_agent is defined there
cat -n tensorrt_llm/_torch/disaggregation/base/agent.py | grep -A 100 'class BaseTransferAgent'
Repository: NVIDIA/TensorRT-LLM
Length of output: 1265
🏁 Script executed:
# Search for any documentation or implementation of the nixl library's add_remote_agent
rg -B 3 -A 3 'def add_remote_agent' tensorrt_llm/
Repository: NVIDIA/TensorRT-LLM
Length of output: 45
🏁 Script executed:
# Check if there are any other usages of load_remote_agent to understand the expected interface
rg 'load_remote_agent' tensorrt_llm/ --type py
Repository: NVIDIA/TensorRT-LLM
Length of output: 780
Pass the name parameter to the nixl library's agent method or document its omission.
The name parameter is declared but unused. The base class interface requires this parameter, and the C++ bindings implementation (BindingsNixlTransferAgent) passes it to load_remote_agent. For consistency, either pass name to add_remote_agent if supported by the Python nixl library, or document why it is not needed.
🧰 Tools
🪛 Ruff (0.14.10)
217-217: Unused method argument: name
(ARG002)
🤖 Prompt for AI Agents
In @tensorrt_llm/_torch/disaggregation/nixl/agent.py around lines 217 - 218, The
load_remote_agent method currently accepts a name parameter but never uses it;
update load_remote_agent to forward name to the nixl library call (call
add_remote_agent with both name and agent_desc) if the Python nixl binding
supports a name argument, mirroring BindingsNixlTransferAgent, or if the binding
does not accept a name then add a short comment/docstring in load_remote_agent
explaining why name is intentionally ignored and that the C++ binding handles
naming instead.
        )
        status = self.agent.transfer(handle)
        assert status != "ERROR"
        return NixlTransferStatus(self.agent, handle)
Avoid assert for runtime error handling in production code.
The assert statement can be disabled with -O flag. Use an explicit exception for error handling that should always be checked.
Proposed fix
handle = self.agent.initialize_xfer(
request.op,
src_xfer_descs,
dst_xfer_descs,
request.remote_name,
request.sync_message,
)
status = self.agent.transfer(handle)
- assert status != "ERROR"
+ if status == "ERROR":
+ raise RuntimeError(f"Transfer initialization failed for request to {request.remote_name}")
     return NixlTransferStatus(self.agent, handle)
🤖 Prompt for AI Agents
In @tensorrt_llm/_torch/disaggregation/nixl/agent.py around lines 246 - 249, The
code uses assert to check transfer result in the block that calls
self.agent.transfer(handle); replace the assert with explicit runtime error
handling: inspect the returned status from self.agent.transfer(handle) and if it
equals "ERROR" raise a descriptive exception (e.g., RuntimeError or a custom
TransferError) including context such as the handle and/or status; otherwise
return NixlTransferStatus(self.agent, handle). Ensure this change is applied
where status = self.agent.transfer(handle) is used (in the method that returns
NixlTransferStatus).
        assert status != "ERROR"
        return NixlTransferStatus(self.agent, handle)

except ImportError:
It's better to move the try/except around the import block into `__init__.py`.
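One way to sketch that suggestion (module paths follow the PR's file list; the binding name is taken from the import shown earlier in the review):

```python
# Contents a package __init__.py might hold so the optional C++ binding is
# probed exactly once, with other modules importing the resolved symbols.
try:
    import tensorrt_llm.tensorrt_llm_transfer_agent_binding as agent_binding
    HAS_CPP_BINDING = True
except ImportError:
    agent_binding = None
    HAS_CPP_BINDING = False

# agent.py would then branch on the flag instead of repeating the try/except:
# from tensorrt_llm._torch.disaggregation.nixl import HAS_CPP_BINDING, agent_binding
```

Centralizing the probe also keeps the fallback decision consistent across every module in the package.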
    def _convert_memory_type(self, py_type: str) -> "MemoryType":
        """Convert Python memory type string to C++ MemoryType."""
        type_map = {
            "DRAM": MemoryType.DRAM,
class MemoryType(Enum):
    ...
then `MemoryType[py_type]` should suffice.
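Sketched out, the `Enum` name lookup replaces the hand-written mapping entirely (the member values here mirror the fallback strings shown earlier in the review):

```python
from enum import Enum


class MemoryType(Enum):
    DRAM = "DRAM"
    VRAM = "VRAM"
    BLK = "BLK"
    OBJ = "OBJ"
    FILE = "FILE"


def convert_memory_type(py_type: str) -> MemoryType:
    # Enum name lookup: MemoryType["VRAM"] is MemoryType.VRAM; raises KeyError
    # on unknown names, which surfaces typos instead of silently mapping them.
    return MemoryType[py_type]
```

The lookup stays correct automatically when members are added, unlike a parallel dict that can drift out of sync.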
Can we have unit tests for all new classes?
OK. I will mark this PR as draft status and add tests. Thanks
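As a starting point, a unit test for `get_local_ip` might look like this (the implementation below is a local stand-in so the sketch stays self-contained; the real test would import it from `native/utils.py`):

```python
import ipaddress
import socket


def get_local_ip() -> str:
    """Local stand-in for the helper in native/utils.py."""
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
            s.connect(("8.8.8.8", 80))  # UDP connect sends no packets
            return s.getsockname()[0]
    except OSError:
        return "127.0.0.1"


def test_get_local_ip_returns_valid_address():
    ip = get_local_ip()
    ipaddress.ip_address(ip)  # raises ValueError if not a valid IP


test_get_local_ip_returns_valid_address()
```

Validating the shape of the result rather than a specific address keeps the test portable across CI hosts.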
1dd0387 to
4b8c619
Compare
Signed-off-by: Shixiaowei02 <[email protected]>
4b8c619 to
cd618b9
Compare
Signed-off-by: Shixiaowei02 <[email protected]>
Summary by CodeRabbit
Release Notes