Skip to content

Commit 1d97a59

Browse files
committed
add rpc test list
Signed-off-by: chunweiy <chunweiy@nvidia.com>
Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>
1 parent 481e784 commit 1d97a59

File tree

13 files changed

+175
-142
lines changed

13 files changed

+175
-142
lines changed

tensorrt_llm/executor/ipc.py

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -194,38 +194,6 @@ async def put_async_noblock(self, obj: Any):
194194
logger.error(traceback.format_exc())
195195
raise e
196196

197-
async def put_async_with_timeout(self, obj: Any, timeout: float = 5.0):
198-
"""
199-
Send an object with timeout to detect connection failures.
200-
201-
Args:
202-
obj: The object to send
203-
timeout: Timeout in seconds for the send operation
204-
205-
Raises:
206-
zmq.Again: If send operation times out (peer may be disconnected)
207-
Exception: Other send errors
208-
"""
209-
self.setup_lazily()
210-
try:
211-
if self.use_hmac_encryption:
212-
data = pickle.dumps(obj) # nosec B301
213-
signed_data = self._sign_data(data)
214-
# Use asyncio.wait_for to implement timeout instead of zmq.NOBLOCK
215-
await asyncio.wait_for(self.socket.send(signed_data),
216-
timeout=timeout)
217-
else:
218-
await asyncio.wait_for(self.socket.send_pyobj(obj),
219-
timeout=timeout)
220-
except asyncio.TimeoutError:
221-
# Convert timeout to zmq.Again to maintain compatibility with existing error handling
222-
raise zmq.Again(
223-
"Send operation timed out - peer may be disconnected")
224-
except Exception as e:
225-
logger.error(f"Error sending object: {e}")
226-
logger.error(traceback.format_exc())
227-
raise e
228-
229197
def get(self) -> Any:
230198
self.setup_lazily()
231199
return self._recv_data()
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# A Lightweight RPC
2+
This is a pure-Python, lightweight RPC layer we built to simplify the existing IPC code in the orchestrator component. It provides multiple call modes (sync, async, future, streaming) and supports both IPC and TCP connections.
3+
4+
## Examples
5+
### Create Server and Client
6+
7+
```python
8+
from tensorrt_llm.executor.rpc import RPCServer, RPCClient
9+
10+
# Define your application
11+
class App:
12+
def add(self, a: int, b: int) -> int:
13+
return a + b
14+
15+
async def async_multiply(self, x: int, y: int) -> int:
16+
return x * y
17+
18+
# Create and start server
19+
app = App()
20+
with RPCServer(app) as server:
21+
server.bind("ipc:///tmp/my_rpc") # or "tcp://127.0.0.1:5555"
22+
server.start()
23+
24+
# Create client and make calls
25+
with RPCClient("ipc:///tmp/my_rpc") as client:
26+
result = client.add(5, 3).remote()
27+
print(result) # Output: 8
28+
```
29+
30+
### Different Remote Calls
31+
32+
#### Synchronous Call
33+
```python
34+
# Blocking call that waits for result
35+
result = client.add(10, 20).remote()
36+
# or with timeout
37+
result = client.add(10, 20).remote(timeout=5.0)
38+
```
39+
40+
#### Asynchronous Call
41+
```python
42+
# Async call that returns a coroutine
43+
result = await client.async_multiply(3, 4).remote_async()
44+
```
45+
46+
#### Future-based Call
47+
```python
48+
# Returns a concurrent.futures.Future
49+
future = client.add(1, 2).remote_future()
50+
# Get result later
51+
result = future.result()
52+
```
53+
54+
#### Fire-and-Forget Call
55+
```python
56+
# Send request without waiting for response
57+
client.submit_task(task_id=123).remote(need_response=False)
58+
```
59+
60+
#### Streaming Call
61+
```python
62+
# For async generator methods
63+
async for value in client.stream_data(n=10).remote_streaming():
64+
print(f"Received: {value}")
65+
```
66+
67+
### Error Handling
68+
```python
69+
from tensorrt_llm.executor.rpc import RPCError, RPCTimeout
70+
71+
try:
72+
result = client.risky_operation().remote(timeout=1.0)
73+
except RPCTimeout:
74+
print("Operation timed out")
75+
except RPCError as e:
76+
print(f"RPC Error: {e}")
77+
print(f"Original cause: {e.cause}")
78+
print(f"Traceback: {e.traceback}")
79+
```
80+
81+
### Graceful Shutdown
82+
```python
83+
# Shutdown server from client
84+
client.shutdown_server()
85+
```

tensorrt_llm/executor/rpc/rpc_client.py

Lines changed: 9 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,13 @@ def remote_future(self,
5959
need_response: bool = True) -> concurrent.futures.Future:
6060
"""Remote call that returns a Future object."""
6161
return self._prepare_and_call(timeout, need_response, "future",
62-
"call_future")
62+
"_call_future")
6363

6464
def remote_streaming(self,
6565
timeout: Optional[float] = None) -> AsyncIterator[Any]:
6666
"""Remote call for streaming results."""
6767
# Streaming always needs a response
68-
return self._prepare_and_call(timeout, True, "async", "call_streaming")
68+
return self._prepare_and_call(timeout, True, "async", "_call_streaming")
6969

7070

7171
class RPCClient:
@@ -365,27 +365,8 @@ def _call_sync(self, method_name, *args, **kwargs):
365365
f"RPC Client _call_sync: Got result for {method_name}: {result}")
366366
return result
367367

368-
def call_async(self, name: str, *args, **kwargs) -> Any:
369-
"""
370-
Call a remote method asynchronously.
371-
372-
Args:
373-
name: Method name to call
374-
*args: Positional arguments
375-
**kwargs: Keyword arguments
376-
377-
Returns:
378-
Coroutine that can be awaited
379-
380-
Example:
381-
result = await client.call_async('remote_method', arg1, arg2, key=value)
382-
"""
383-
if "__rpc_params" not in kwargs:
384-
kwargs["__rpc_params"] = RPCParams(need_response=True)
385-
return self._call_async(name, *args, **kwargs)
386-
387-
def call_future(self, name: str, *args,
388-
**kwargs) -> concurrent.futures.Future:
368+
def _call_future(self, name: str, *args,
369+
**kwargs) -> concurrent.futures.Future:
389370
"""
390371
Call a remote method and return a Future.
391372
@@ -396,12 +377,6 @@ def call_future(self, name: str, *args,
396377
397378
Returns:
398379
A Future object that can be used to retrieve the result
399-
400-
Example:
401-
future = client.call_future('remote_method', arg1, arg2, key=value)
402-
result = future.result() # blocks until complete
403-
# or
404-
future.add_done_callback(lambda f: print(f.result()))
405380
"""
406381

407382
def _async_to_sync():
@@ -412,25 +387,8 @@ def _async_to_sync():
412387

413388
return self._executor.submit(_async_to_sync)
414389

415-
def call_sync(self, name: str, *args, **kwargs) -> Any:
416-
"""
417-
Call a remote method synchronously (blocking).
418-
419-
Args:
420-
name: Method name to call
421-
*args: Positional arguments
422-
**kwargs: Keyword arguments
423-
424-
Returns:
425-
The result of the remote method call
426-
427-
Example:
428-
result = client.call_sync('remote_method', arg1, arg2, key=value)
429-
"""
430-
return self._call_sync(name, *args, **kwargs)
431-
432-
async def call_streaming(self, name: str, *args,
433-
**kwargs) -> AsyncIterator[Any]:
390+
async def _call_streaming(self, name: str, *args,
391+
**kwargs) -> AsyncIterator[Any]:
434392
"""
435393
Call a remote async generator method and get streaming results.
436394
@@ -441,10 +399,6 @@ async def call_streaming(self, name: str, *args,
441399
442400
Yields:
443401
Results from the remote async generator
444-
445-
Example:
446-
async for result in client.call_streaming('streaming_task'):
447-
print(result)
448402
"""
449403
if self._server_stopped:
450404
raise RPCCancelled("Server is shutting down, request cancelled")
@@ -474,7 +428,7 @@ async def call_streaming(self, name: str, *args,
474428

475429
# Read streaming responses
476430
while True:
477-
logger_debug(f"RPC Client call_streaming waiting for response",
431+
logger_debug(f"RPC Client _call_streaming waiting for response",
478432
color="green")
479433
if timeout is None:
480434
response = await queue.get()
@@ -483,14 +437,14 @@ async def call_streaming(self, name: str, *args,
483437
timeout=timeout)
484438

485439
logger_debug(
486-
f"RPC Client call_streaming received [{response.stream_status}] response: {response}",
440+
f"RPC Client _call_streaming received [{response.stream_status}] response: {response}",
487441
color="green")
488442
if response.stream_status == 'start':
489443
# Start of stream
490444
continue
491445
elif response.stream_status == 'data':
492446
logger_debug(
493-
f"RPC Client call_streaming received data: {response.result}",
447+
f"RPC Client _call_streaming received data: {response.result}",
494448
color="green")
495449
yield response.result
496450
elif response.stream_status == 'end':

tensorrt_llm/executor/rpc_proxy.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import json
44
import os
55
import threading
6-
import time
76
from typing import Optional
87

98
from ..llmapi.llm_args import KvCacheConnectorConfig
@@ -71,7 +70,6 @@ def __init__(
7170
self.main_loop = None
7271

7372
self.launch_workers()
74-
time.sleep(1) # wait for the workers to launch
7573

7674
# Invoke model creation on the remote
7775
# TBD: Move model creation to the mpi task, or left in RPC?

tensorrt_llm/executor/rpc_worker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def main_task(
230230
color="yellow")
231231
worker.setup_engine()
232232

233-
if mpi_rank() == 0:
233+
else:
234234
logger_debug(f"Worker {mpi_rank()} is creating the RPC service",
235235
color="yellow")
236236
# Step 2: Create the RPC service, it will expose all the APIs of the worker as remote call to the client

tests/integration/test_lists/test-db/l0_a10.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ l0_a10:
4646
- unittest/llmapi/test_serialization.py
4747
- unittest/llmapi/test_utils.py
4848
- unittest/llmapi/test_llm_args.py
49+
# executor
50+
- unittest/executor/test_rpc.py
4951
- condition:
5052
ranges:
5153
system_gpu_count:

tests/integration/test_lists/test-db/l0_a100.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ l0_a100:
1616
- unittest/llmapi/test_llm_pytorch.py
1717
- unittest/llmapi/test_mpi_session.py # generic tests
1818
- unittest/trt/model_api/test_model_quantization.py
19+
# executor
20+
- unittest/executor/test_base_worker.py
21+
- unittest/executor/test_rpc_proxy.py
22+
- unittest/executor/test_rpc_worker.py
1923
- condition:
2024
ranges:
2125
system_gpu_count:

tests/unittest/executor/test_base_worker.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# isort: off
1313
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/..")
1414
from utils.llm_data import llm_models_root
15+
from utils.util import skip_single_gpu
1516
# isort: on
1617

1718
from tensorrt_llm._torch.pyexecutor.config import update_executor_config
@@ -156,6 +157,8 @@ def create_worker_session(self):
156157
session = MpiPoolSession(n_workers=2)
157158
return session
158159

160+
@pytest.mark.gpu2
161+
@skip_single_gpu
159162
def test_create_executor(self):
160163
futures = self.session.submit(
161164
TestRpcWorkerBaseTP2.create_executor,

tests/unittest/executor/test_rpc.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -238,10 +238,10 @@ def slow_method(self):
238238
with pytest.raises(RPCError) as exc_info:
239239
client.slow_method().remote(timeout=0.5)
240240

241-
error = exc_info.value
242-
# Should be either a timeout error or RPC error indicating timeout
243-
assert "timed out" in str(
244-
error).lower() or "timeout" in str(error).lower()
241+
error = exc_info.value
242+
# Should be either a timeout error or RPC error indicating timeout
243+
assert "timed out" in str(error).lower() or "timeout" in str(
244+
error).lower()
245245

246246
def test_method_not_found_error(self):
247247
"""Test that calling non-existent methods returns proper error."""

tests/unittest/executor/test_rpc_proxy.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# isort: off
1515
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/..")
1616
from utils.llm_data import llm_models_root
17-
from utils.util import similar
17+
from utils.util import similar, skip_single_gpu
1818
# isort: on
1919

2020
model_path = llm_models_root() / "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"
@@ -78,6 +78,8 @@ def test_tp1(self, num_reqs):
7878
assert isinstance(kv_cache_events, list)
7979

8080
@pytest.mark.parametrize("num_reqs", [1, 10])
81+
@skip_single_gpu
82+
@pytest.mark.gpu2
8183
def test_tp2(self, num_reqs):
8284
tokenizer = TransformersTokenizer.from_pretrained(model_path)
8385
prompt = "A B C D"

0 commit comments

Comments
 (0)