Skip to content

Commit 3eae58c

Browse files
authored
Add disaggregated unittest (NVIDIA#4899)
Signed-off-by: Shunkang <[email protected]>
1 parent a152635 commit 3eae58c

File tree

5 files changed

+138
-243
lines changed

5 files changed

+138
-243
lines changed

tensorrt_llm/commands/serve.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,7 @@ def serve(model: str, tokenizer: Optional[str], host: str, port: int,
292292
metadata_server_config_file)
293293

294294
if metadata_server_cfg is not None:
295+
assert server_role is not None, "server_role is required when metadata_server_cfg is provided"
295296
try:
296297
server_role = ServerRole[server_role.upper()]
297298
except ValueError:

tests/integration/test_lists/test-db/l0_h100.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ l0_h100:
1919
- unittest/_torch -k "modeling_llama"
2020
- unittest/_torch/modeling -k "modeling_mixtral"
2121
- unittest/_torch/modeling -k "modeling_nemotron"
22+
- unittest/disaggregated/test_disagg_utils.py
23+
- unittest/disaggregated/test_router.py
24+
- unittest/disaggregated/test_remoteDictionary.py
2225
- accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype
2326
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=TRTLLM-torch_compile=False]
2427
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=TRTLLM-torch_compile=True]

tests/unittest/disaggregated/test_disagg_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ def verify_disagg_config(config: DisaggServerConfig):
5757
assert config.ctx_router_config.type == "round_robin"
5858
assert config.gen_router_config.type == "load_balancing"
5959
assert len(config.server_configs) == 3
60-
assert config.condition is None
6160

6261

6362
def test_parse_disagg_config_file(sample_yaml_file):

tests/unittest/disaggregated/test_remoteDictionary.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,46 @@
1+
import os
2+
import signal
3+
import subprocess
4+
import time
15
import unittest
26

37
import etcd3
48

59
from tensorrt_llm.serve.metadata_server import EtcdDictionary
610

711

12+
def start_etcd_server():
13+
# Command to start etcd
14+
etcd_cmd = ["etcd"]
15+
16+
# Start etcd in background
17+
process = subprocess.Popen(
18+
etcd_cmd,
19+
stdout=subprocess.PIPE,
20+
stderr=subprocess.PIPE,
21+
preexec_fn=os.setsid) # This makes it run in a new process group
22+
23+
# Wait a bit for etcd to start
24+
time.sleep(5)
25+
26+
return process
27+
28+
29+
def stop_etcd_server(process):
30+
# Kill the process group
31+
os.killpg(os.getpgid(process.pid), signal.SIGTERM)
32+
process.wait()
33+
34+
835
class TestEtcdDictionary(unittest.TestCase):
936

1037
def setUp(self):
38+
# Set the protocol buffers implementation to python
39+
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
40+
41+
# Start etcd server
42+
self.etcd_process = start_etcd_server()
43+
1144
# Setup etcd connection parameters
1245
self.host = "localhost"
1346
self.port = 2379
@@ -25,6 +58,12 @@ def tearDown(self):
2558
# Clean up test keys after each test
2659
self._cleanup_test_keys()
2760

61+
# Stop etcd server
62+
stop_etcd_server(self.etcd_process)
63+
64+
# Unset the protocol buffers implementation
65+
del os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"]
66+
2867
def _cleanup_test_keys(self):
2968
# Helper method to remove test keys
3069
test_keys = [
@@ -47,6 +86,8 @@ def test_put_and_get(self):
4786
# Assert
4887
self.assertEqual(value.decode('utf-8'), test_value)
4988

89+
self._cleanup_test_keys()
90+
5091
def test_remove(self):
5192
# Test removing a key
5293
test_key = "trtllm/1/test_key2"
@@ -63,6 +104,8 @@ def test_remove(self):
63104
self.assertIsNone(
64105
result[0]) # etcd3 returns (None, None) when key doesn't exist
65106

107+
self._cleanup_test_keys()
108+
66109
def test_keys(self):
67110
# Test listing all keys
68111
test_data = {
@@ -85,6 +128,8 @@ def test_keys(self):
85128
extract_keys = set(keys)
86129
self.assertEqual(prefix_keys, extract_keys)
87130

131+
self._cleanup_test_keys()
132+
88133
def test_get_nonexistent_key(self):
89134
# Test getting a key that doesn't exist
90135
result, _ = self.etcd_dict.get("nonexistent_key")
@@ -108,6 +153,8 @@ def test_put_update_existing(self):
108153
# Assert
109154
self.assertEqual(value.decode('utf-8'), updated_value)
110155

156+
self._cleanup_test_keys()
157+
111158

112159
if __name__ == '__main__':
113160
unittest.main()

0 commit comments

Comments
 (0)