@@ -151,8 +151,18 @@ async def send_request_to_service(client_info: dict, endpoint: str,
151151 Send a request to a service using a client from the pool.
152152 """
153153 req_data = req_data .copy ()
154- req_data ['do_remote_decode' ] = True
154+ req_data ['kv_transfer_params' ] = {
155+ "do_remote_decode" : True ,
156+ "do_remote_prefill" : False ,
157+ "remote_engine_id" : None ,
158+ "remote_block_ids" : None ,
159+ "remote_host" : None ,
160+ "remote_port" : None
161+ }
155162 req_data ["stream" ] = False
163+ req_data ["max_tokens" ] = 1
164+ if "stream_options" in req_data :
165+ del req_data ["stream_options" ]
156166 headers = {
157167 "Authorization" : f"Bearer { os .environ .get ('OPENAI_API_KEY' )} " ,
158168 "X-Request-Id" : request_id
@@ -167,22 +177,14 @@ async def send_request_to_service(client_info: dict, endpoint: str,
167177
168178
169179async def stream_service_response (client_info : dict , endpoint : str ,
170- req_data : dict , remote_block_ids : list [int ],
171- remote_engine_id : str , remote_host : str ,
172- remote_port : int , request_id : str ):
180+ req_data : dict , request_id : str ):
173181 """
174182 Asynchronously stream response from a service using a client from the pool.
175183 """
176184 headers = {
177185 "Authorization" : f"Bearer { os .environ .get ('OPENAI_API_KEY' )} " ,
178186 "X-Request-Id" : request_id
179187 }
180- req_data = req_data .copy ()
181- req_data ['do_remote_prefill' ] = True
182- req_data ["remote_block_ids" ] = remote_block_ids
183- req_data ['remote_engine_id' ] = remote_engine_id
184- req_data ["remote_host" ] = remote_host
185- req_data ["remote_port" ] = remote_port
186188
187189 async with client_info ['client' ].stream ("POST" ,
188190 endpoint ,
@@ -209,10 +211,9 @@ async def handle_completions(request: Request):
209211
210212 # Extract the needed fields
211213 response_json = response .json ()
212- remote_block_ids = response_json .get ('remote_block_ids' , [])
213- remote_engine_id = response_json .get ('remote_engine_id' , '' )
214- remote_host = response_json .get ('remote_host' , '' )
215- remote_port = response_json .get ('remote_port' , 0 )
214+ kv_transfer_params = response_json .get ('kv_transfer_params' , {})
215+ if kv_transfer_params :
216+ req_data ["kv_transfer_params" ] = kv_transfer_params
216217
217218 # Get the next decode client in round-robin fashion
218219 decode_client_info = get_next_client (request .app , 'decode' )
@@ -221,15 +222,10 @@ async def handle_completions(request: Request):
221222
222223 # Stream response from decode service
223224 async def generate_stream ():
224- async for chunk in stream_service_response (
225- decode_client_info ,
226- "/completions" ,
227- req_data ,
228- remote_block_ids = remote_block_ids ,
229- remote_engine_id = remote_engine_id ,
230- remote_host = remote_host ,
231- remote_port = remote_port ,
232- request_id = request_id ):
225+ async for chunk in stream_service_response (decode_client_info ,
226+ "/completions" ,
227+ req_data ,
228+ request_id = request_id ):
233229 yield chunk
234230
235231 return StreamingResponse (generate_stream (),
0 commit comments