1+ import os
2+ import signal
3+ import subprocess
4+ import time
15import unittest
26
37import etcd3
48
59from tensorrt_llm .serve .metadata_server import EtcdDictionary
610
711
12+ def start_etcd_server ():
13+ # Command to start etcd
14+ etcd_cmd = ["etcd" ]
15+
16+ # Start etcd in background
17+ process = subprocess .Popen (
18+ etcd_cmd ,
19+ stdout = subprocess .PIPE ,
20+ stderr = subprocess .PIPE ,
21+ preexec_fn = os .setsid ) # This makes it run in a new process group
22+
23+ # Wait a bit for etcd to start
24+ time .sleep (5 )
25+
26+ return process
27+
28+
29+ def stop_etcd_server (process ):
30+ # Kill the process group
31+ os .killpg (os .getpgid (process .pid ), signal .SIGTERM )
32+ process .wait ()
33+
34+
835class TestEtcdDictionary (unittest .TestCase ):
936
1037 def setUp (self ):
38+ # Set the protocol buffers implementation to python
39+ os .environ ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION" ] = "python"
40+
41+ # Start etcd server
42+ self .etcd_process = start_etcd_server ()
43+
1144 # Setup etcd connection parameters
1245 self .host = "localhost"
1346 self .port = 2379
@@ -25,6 +58,12 @@ def tearDown(self):
2558 # Clean up test keys after each test
2659 self ._cleanup_test_keys ()
2760
61+ # Stop etcd server
62+ stop_etcd_server (self .etcd_process )
63+
64+ # Unset the protocol buffers implementation
65+ del os .environ ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION" ]
66+
2867 def _cleanup_test_keys (self ):
2968 # Helper method to remove test keys
3069 test_keys = [
@@ -47,6 +86,8 @@ def test_put_and_get(self):
4786 # Assert
4887 self .assertEqual (value .decode ('utf-8' ), test_value )
4988
89+ self ._cleanup_test_keys ()
90+
5091 def test_remove (self ):
5192 # Test removing a key
5293 test_key = "trtllm/1/test_key2"
@@ -63,6 +104,8 @@ def test_remove(self):
63104 self .assertIsNone (
64105 result [0 ]) # etcd3 returns (None, None) when key doesn't exist
65106
107+ self ._cleanup_test_keys ()
108+
66109 def test_keys (self ):
67110 # Test listing all keys
68111 test_data = {
@@ -85,6 +128,8 @@ def test_keys(self):
85128 extract_keys = set (keys )
86129 self .assertEqual (prefix_keys , extract_keys )
87130
131+ self ._cleanup_test_keys ()
132+
88133 def test_get_nonexistent_key (self ):
89134 # Test getting a key that doesn't exist
90135 result , _ = self .etcd_dict .get ("nonexistent_key" )
@@ -108,6 +153,8 @@ def test_put_update_existing(self):
108153 # Assert
109154 self .assertEqual (value .decode ('utf-8' ), updated_value )
110155
156+ self ._cleanup_test_keys ()
157+
111158
112159if __name__ == '__main__' :
113160 unittest .main ()
0 commit comments