@@ -154,19 +154,56 @@ def create_app(
 
     return app
 
+def prepare_request_resources(
+    body: CreateCompletionRequest | CreateChatCompletionRequest,
+    llama_proxy: LlamaProxy,
+    body_model: str,
+    kwargs) -> llama_cpp.Llama:
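+    # Resolve the requested model through the proxy and populate kwargs with the
+    # logit-bias, grammar, and min-tokens settings shared by both endpoints.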
+    if llama_proxy is None:
+        raise HTTPException(
+            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+            detail="Service is not available",
+        )
+    llama = llama_proxy(body_model)
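+    # Translate logit_bias keys from token strings to token ids when the client
+    # sent them as tokens.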
+    if body.logit_bias is not None:
+        kwargs["logit_bias"] = (
+            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
+            if body.logit_bias_type == "tokens"
+            else body.logit_bias
+        )
+
+    if body.grammar is not None:
+        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
+
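+    # Suppress EOS until at least min_tokens have been generated.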
+    if body.min_tokens > 0:
+        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
+            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
+        )
+        if "logits_processor" not in kwargs:
+            kwargs["logits_processor"] = _min_tokens_logits_processor
+        else:
+            kwargs["logits_processor"].extend(_min_tokens_logits_processor)
+    return llama
+
 
 async def get_event_publisher(
     request: Request,
     inner_send_chan: MemoryObjectSendStream[typing.Any],
-    iterator: Iterator[typing.Any],
-    on_complete: typing.Optional[typing.Callable[[], typing.Awaitable[None]]] = None,
+    body: CreateCompletionRequest | CreateChatCompletionRequest,
+    body_model: str,
+    llama_call,
+    kwargs,
 ):
     server_settings = next(get_server_settings())
     interrupt_requests = (
         server_settings.interrupt_requests if server_settings else False
     )
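+    # Acquire the llama proxy here rather than via a FastAPI dependency so the
+    # model stays locked for the full lifetime of the streamed response.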
+    exit_stack = contextlib.AsyncExitStack()
+    llama_proxy: LlamaProxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
+    llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
     async with inner_send_chan:
         try:
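+            # The blocking llama call runs in a worker thread and yields the stream iterator.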
+            iterator = await run_in_threadpool(llama_call, llama, **kwargs)
             async for chunk in iterate_in_threadpool(iterator):
                 await inner_send_chan.send(dict(data=json.dumps(chunk)))
                 if await request.is_disconnected():
@@ -181,8 +218,7 @@ async def get_event_publisher(
                 print(f"Disconnected from client (via refresh/close) {request.client}")
                 raise e
         finally:
-            if on_complete:
-                await on_complete()
+            await exit_stack.aclose()
 
 
 def _logit_bias_tokens_to_input_ids(
@@ -267,18 +303,11 @@ async def create_completion(
     request: Request,
     body: CreateCompletionRequest,
 ) -> llama_cpp.Completion:
-    exit_stack = contextlib.AsyncExitStack()
-    llama_proxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
-    if llama_proxy is None:
-        raise HTTPException(
-            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-            detail="Service is not available",
-        )
     if isinstance(body.prompt, list):
         assert len(body.prompt) <= 1
         body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""
 
-    llama = llama_proxy(
+    body_model = (
         body.model
         if request.url.path != "/v1/engines/copilot-codex/completions"
         else "copilot-codex"
@@ -293,60 +322,42 @@ async def create_completion(
     }
     kwargs = body.model_dump(exclude=exclude)
 
-    if body.logit_bias is not None:
-        kwargs["logit_bias"] = (
-            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
-            if body.logit_bias_type == "tokens"
-            else body.logit_bias
-        )
-
-    if body.grammar is not None:
-        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
-
-    if body.min_tokens > 0:
-        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
-            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
-        )
-        if "logits_processor" not in kwargs:
-            kwargs["logits_processor"] = _min_tokens_logits_processor
-        else:
-            kwargs["logits_processor"].extend(_min_tokens_logits_processor)
-
-    try:
-        iterator_or_completion: Union[
-            llama_cpp.CreateCompletionResponse,
-            Iterator[llama_cpp.CreateCompletionStreamResponse],
-        ] = await run_in_threadpool(llama, **kwargs)
-    except Exception as err:
-        await exit_stack.aclose()
-        raise err
-
-    if isinstance(iterator_or_completion, Iterator):
-        # EAFP: It's easier to ask for forgiveness than permission
-        first_response = await run_in_threadpool(next, iterator_or_completion)
-
-        # If no exception was raised from first_response, we can assume that
-        # the iterator is valid and we can use it to stream the response.
-        def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
-            yield first_response
-            yield from iterator_or_completion
-
+    # handle streaming request
+    if kwargs.get("stream", False):
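+        # Defer model acquisition and the llama call to get_event_publisher so the
+        # proxy is held and released inside the streaming task itself.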
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
             recv_chan,
             data_sender_callable=partial(  # type: ignore
                 get_event_publisher,
                 request=request,
                 inner_send_chan=send_chan,
-                iterator=iterator(),
-                on_complete=exit_stack.aclose,
+                body=body,
+                body_model=body_model,
+                llama_call=llama_cpp.Llama.__call__,
+                kwargs=kwargs,
             ),
             sep="\n",
             ping_message_factory=_ping_message_factory,
         )
-    else:
+
+    # handle regular request
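+    # Manage the llama proxy with an explicit AsyncExitStack so it is released
+    # as soon as the blocking call below completes.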
+    exit_stack = contextlib.AsyncExitStack()
+    llama_proxy: LlamaProxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
+    llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
+
+    if await request.is_disconnected():
+        print(f"Disconnected from client (via refresh/close) before llm invoked {request.client}")
         await exit_stack.aclose()
-        return iterator_or_completion
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="Client closed request",
+        )
+
+    try:
+        completion: llama_cpp.CreateCompletionResponse = await run_in_threadpool(llama, **kwargs)
+    finally:
+        await exit_stack.aclose()
+    return completion
 
 
 @router.post(
@@ -474,74 +485,52 @@ async def create_chat_completion(
     # where the dependency is cleaned up before a StreamingResponse
     # is complete.
     # https://github.com/tiangolo/fastapi/issues/11143
-    exit_stack = contextlib.AsyncExitStack()
-    llama_proxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
-    if llama_proxy is None:
-        raise HTTPException(
-            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-            detail="Service is not available",
-        )
+
+    body_model = body.model
     exclude = {
         "n",
         "logit_bias_type",
         "user",
         "min_tokens",
     }
     kwargs = body.model_dump(exclude=exclude)
-    llama = llama_proxy(body.model)
-    if body.logit_bias is not None:
-        kwargs["logit_bias"] = (
-            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
-            if body.logit_bias_type == "tokens"
-            else body.logit_bias
-        )
-
-    if body.grammar is not None:
-        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
-
-    if body.min_tokens > 0:
-        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
-            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
-        )
-        if "logits_processor" not in kwargs:
-            kwargs["logits_processor"] = _min_tokens_logits_processor
-        else:
-            kwargs["logits_processor"].extend(_min_tokens_logits_processor)
-
-    try:
-        iterator_or_completion: Union[
-            llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk]
-        ] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
-    except Exception as err:
-        await exit_stack.aclose()
-        raise err
-
-    if isinstance(iterator_or_completion, Iterator):
-        # EAFP: It's easier to ask for forgiveness than permission
-        first_response = await run_in_threadpool(next, iterator_or_completion)
-
-        # If no exception was raised from first_response, we can assume that
-        # the iterator is valid and we can use it to stream the response.
-        def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
-            yield first_response
-            yield from iterator_or_completion
 
+    # handle streaming request
+    if kwargs.get("stream", False):
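+        # As in create_completion, the model is acquired inside get_event_publisher
+        # so the streaming task owns the proxy for its full lifetime.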
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
             recv_chan,
             data_sender_callable=partial(  # type: ignore
                 get_event_publisher,
                 request=request,
                 inner_send_chan=send_chan,
-                iterator=iterator(),
-                on_complete=exit_stack.aclose,
+                body=body,
+                body_model=body_model,
+                llama_call=llama_cpp.Llama.create_chat_completion,
+                kwargs=kwargs,
             ),
             sep="\n",
             ping_message_factory=_ping_message_factory,
         )
-    else:
+
+    # handle regular request
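+    # Acquire the proxy only for the duration of the blocking chat completion call;
+    # the finally block releases it even if the call raises.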
+    exit_stack = contextlib.AsyncExitStack()
+    llama_proxy: LlamaProxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
+    llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
+
+    if await request.is_disconnected():
+        print(f"Disconnected from client (via refresh/close) before llm invoked {request.client}")
+        await exit_stack.aclose()
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="Client closed request",
+        )
+
+    try:
+        completion: llama_cpp.ChatCompletion = await run_in_threadpool(llama.create_chat_completion, **kwargs)
+    finally:
         await exit_stack.aclose()
-        return iterator_or_completion
+    return completion
 
 
 @router.get(