@@ -67,6 +67,41 @@ def __exit__(self, exc_type, exc_val, exc_tb):
6767 return False
6868
6969
70+ def check_port_available (port : int ) -> int :
71+ import socket
72+ try :
73+ with socket .socket (socket .AF_INET , socket .SOCK_STREAM ) as s :
74+ s .bind (('localhost' , port ))
75+ return port
76+ except socket .error :
77+ # find a free port
78+ sock = socket .socket ()
79+ sock .bind (('' , 0 ))
80+ return sock .getsockname ()[1 ]
81+
82+
83+ def revise_disaggregated_server_config_urls_with_free_ports (
84+ disaggregated_server_config : Dict [str , Any ]) -> Dict [str , Any ]:
85+ disaggregated_server_config ['port' ] = check_port_available (
86+ disaggregated_server_config ['port' ])
87+ ctx_urls = disaggregated_server_config ["context_servers" ]["urls" ]
88+ gen_urls = disaggregated_server_config ["generation_servers" ]["urls" ]
89+
90+ new_ctx_urls = []
91+ new_gen_urls = []
92+ for url in ctx_urls :
93+ port = check_port_available (int (url .split (":" )[1 ]))
94+ new_ctx_urls .append (f"localhost:{ port } " )
95+ for url in gen_urls :
96+ port = check_port_available (int (url .split (":" )[1 ]))
97+ new_gen_urls .append (f"localhost:{ port } " )
98+
99+ disaggregated_server_config ["context_servers" ]["urls" ] = new_ctx_urls
100+ disaggregated_server_config ["generation_servers" ]["urls" ] = new_gen_urls
101+
102+ return disaggregated_server_config
103+
104+
70105@contextlib .contextmanager
71106def launch_disaggregated_llm (
72107 disaggregated_server_config : Dict [str , Any ],
@@ -87,6 +122,9 @@ def launch_disaggregated_llm(
87122 f"Using unified tp parameter for testing is not recommended. Please use server configs instead."
88123 )
89124
125+ disaggregated_server_config = revise_disaggregated_server_config_urls_with_free_ports (
126+ disaggregated_server_config )
127+
90128 with open (disaggregated_serving_config_path , "w" ) as f :
91129 yaml .dump (disaggregated_server_config , f )
92130 ctx_server_config_path = os .path .join (temp_dir .name ,
@@ -138,6 +176,7 @@ def launch_disaggregated_llm(
138176 ctx_urls = disaggregated_server_config ["context_servers" ]["urls" ]
139177 gen_urls = disaggregated_server_config ["generation_servers" ]["urls" ]
140178
179+ serve_port = disaggregated_server_config ["port" ]
141180 ctx_ports = [int (url .split (":" )[1 ]) for url in ctx_urls ]
142181 gen_ports = [int (url .split (":" )[1 ]) for url in gen_urls ]
143182
@@ -236,14 +275,14 @@ def multi_popen(server_configs, server_name="", enable_redirect_log=False):
236275 )
237276 try :
238277 print ("Checking health endpoint" )
239- response = requests .get ("http://localhost:8000 /health" )
278+ response = requests .get (f "http://localhost:{ serve_port } /health" )
240279 if response .status_code == 200 :
241280 break
242281 except requests .exceptions .ConnectionError :
243282 continue
244283
245284 client = openai .OpenAI (api_key = "1234567890" ,
246- base_url = f"http://localhost:8000 /v1" ,
285+ base_url = f"http://localhost:{ serve_port } /v1" ,
247286 timeout = 1800000 )
248287
249288 def send_request (prompt : str , sampling_params : SamplingParams ,
0 commit comments