Skip to content

Commit e4c7078

Browse files
authored
[None][fix] enable hmac in RPC (#9745)
Signed-off-by: Superjomn <[email protected]>
1 parent 2645a78 commit e4c7078

File tree

9 files changed

+57
-7
lines changed

9 files changed

+57
-7
lines changed

tensorrt_llm/executor/ray_executor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ def __init__(self,
8282
is_llm_executor=is_llm_executor)
8383

8484
self.init_rpc_executor()
85+
# Inject the generated HMAC key into worker_kwargs for workers
86+
worker_kwargs['hmac_key'] = self.hmac_key
8587
worker_kwargs['rpc_addr'] = self.rpc_addr
8688
self.create_workers(RayGPUWorker, worker_kwargs)
8789
self.setup_engine_remote()

tensorrt_llm/executor/ray_gpu_worker.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ def __init__(
168168
tokenizer: Optional[TokenizerBase] = None,
169169
llm_args: Optional[BaseLlmArgs] = None,
170170
rpc_addr: Optional[str] = None,
171+
hmac_key: Optional[bytes] = None,
171172
) -> None:
172173
global logger
173174
from tensorrt_llm.logger import logger
@@ -191,7 +192,7 @@ def __init__(
191192
if rpc_addr is None:
192193
raise RuntimeError(
193194
"RPC mode enabled but no rpc_addr provided to RayGPUWorker")
194-
self.init_rpc_worker(self.global_rank, rpc_addr)
195+
self.init_rpc_worker(self.global_rank, rpc_addr, hmac_key)
195196
self.start_rpc_server()
196197

197198
def setup_engine(self):

tensorrt_llm/executor/rpc/rpc_client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,8 @@ def __init__(self,
108108
self._client_socket = ZeroMqQueue(address=(address, hmac_key),
109109
is_server=False,
110110
is_async=True,
111-
use_hmac_encryption=False,
111+
use_hmac_encryption=hmac_key
112+
is not None,
112113
socket_type=socket_type,
113114
name="rpc_client")
114115
self._pending_futures = {}

tensorrt_llm/executor/rpc/rpc_server.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,8 @@ def bind(self, address: str = "tcp://*:5555") -> None:
108108
self._client_socket = ZeroMqQueue(address=(address, self._hmac_key),
109109
is_server=True,
110110
is_async=True,
111-
use_hmac_encryption=False,
111+
use_hmac_encryption=self._hmac_key
112+
is not None,
112113
socket_type=socket_type,
113114
name="rpc_server")
114115
logger.info(f"RPCServer is bound to {self._address}")

tensorrt_llm/executor/rpc_proxy.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ def __init__(
4848

4949
self._create_mpi_session(model_world_size, mpi_session)
5050

51+
# Inject the generated HMAC key into worker_kwargs for workers
52+
worker_kwargs['hmac_key'] = self.hmac_key
5153
self.worker_kwargs = worker_kwargs
5254

5355
self.launch_workers()

tensorrt_llm/executor/rpc_proxy_mixin.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import atexit
33
import json
4+
import os
45
import threading
56
from typing import Callable, List, Optional
67

@@ -29,7 +30,8 @@ class RpcExecutorMixin:
2930

3031
def init_rpc_executor(self):
3132
self.rpc_addr = get_unique_ipc_addr()
32-
self.rpc_client = RPCClient(self.rpc_addr)
33+
self.hmac_key = os.urandom(32)
34+
self.rpc_client = RPCClient(self.rpc_addr, hmac_key=self.hmac_key)
3335

3436
self._results = {}
3537
self._shutdown_event = threading.Event()

tensorrt_llm/executor/rpc_worker.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,10 @@ def main_task(
155155
color="yellow")
156156
# Step 2: Create the RPC service, it will expose all the APIs of the worker as remote call to the client
157157
# Set num_workers to larger than 1 since there are some streaming tasks runs infinitely, such as await_responses_async.
158-
rpc_server = RPCServer(worker, num_workers=worker.num_workers)
158+
hmac_key = kwargs.get("hmac_key")
159+
rpc_server = RPCServer(worker,
160+
num_workers=worker.num_workers,
161+
hmac_key=hmac_key)
159162
rpc_server.bind(rpc_addr)
160163
rpc_server.start()
161164
logger_debug(f"[worker] RPC server {mpi_rank()} is started",

tensorrt_llm/executor/rpc_worker_mixin.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,11 @@ class RpcWorkerMixin:
2525
# This can be overridden by setting num_workers in the inheriting class
2626
NUM_WORKERS = 6
2727

28-
def init_rpc_worker(self, rank: int, rpc_addr: Optional[str]):
28+
def init_rpc_worker(self, rank: int, rpc_addr: Optional[str], hmac_key: Optional[bytes] = None):
2929
if rpc_addr is None:
3030
raise RuntimeError("RPC mode enabled but no rpc_addr provided to worker")
3131

32+
self.hmac_key = hmac_key
3233
self.rank = rank
3334
self.shutdown_event = Event()
3435
self._response_queue = Queue()
@@ -41,7 +42,7 @@ def start_rpc_server(self):
4142
if self.rank == 0:
4243
# Use num_workers if set on the instance, otherwise use class default
4344
num_workers = getattr(self, "num_workers", RpcWorkerMixin.NUM_WORKERS)
44-
self.rpc_server = RPCServer(self, num_workers=num_workers)
45+
self.rpc_server = RPCServer(self, num_workers=num_workers, hmac_key=self.hmac_key)
4546
self.rpc_server.bind(self.rpc_addr)
4647
self.rpc_server.start()
4748

tests/unittest/executor/test_rpc_proxy.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,43 @@ def test_tp2(self, num_reqs):
9595
assert similar(tokenizer.decode(result.outputs[0].token_ids),
9696
'E F G H I J K L')
9797

98+
def test_hmac_key_generation(self):
99+
"""Test that HMAC key is automatically generated and properly propagated."""
100+
tokenizer = TransformersTokenizer.from_pretrained(model_path)
101+
prompt = "A B C D"
102+
prompt_token_ids = tokenizer.encode(prompt)
103+
max_tokens = 8
104+
105+
with self.create_proxy(tp_size=1) as proxy:
106+
assert proxy.hmac_key is not None, "HMAC key should be generated"
107+
assert len(
108+
proxy.hmac_key
109+
) == 32, f"HMAC key should be 32 bytes, got {len(proxy.hmac_key)}"
110+
111+
# Verify key is properly stored in worker_kwargs
112+
assert 'hmac_key' in proxy.worker_kwargs, "HMAC key should be in worker_kwargs"
113+
assert proxy.worker_kwargs[
114+
'hmac_key'] is not None, "HMAC key in worker_kwargs should not be None"
115+
116+
# Verify both references point to the same key object
117+
assert proxy.hmac_key is proxy.worker_kwargs['hmac_key'], \
118+
"HMAC key should be the same object in both locations"
119+
120+
logger_debug(
121+
f"[Test] HMAC key verified: length={len(proxy.hmac_key)} bytes",
122+
color="green")
123+
124+
# Verify RPC communication works with the generated key
125+
sampling_params = SamplingParams(max_tokens=max_tokens)
126+
result = proxy.generate(prompt_token_ids, sampling_params)
127+
assert similar(
128+
tokenizer.decode(result.outputs[0].token_ids), 'E F G H I J K L'
129+
), "Generation should work with auto-generated HMAC key"
130+
131+
logger_debug(
132+
f"[Test] HMAC key test passed: RPC communication successful",
133+
color="green")
134+
98135

99136
if __name__ == "__main__":
100137
TestRpcProxy().test_tp1(20)

0 commit comments

Comments
 (0)