@@ -7,7 +7,7 @@
 
 from anyio import Lock
 from functools import partial
-from typing import Iterator, List, Optional, Union, Dict
+from typing import List, Optional, Union, Dict
 
 import llama_cpp
 
@@ -155,34 +155,71 @@ def create_app(
     return app
 
 
+def prepare_request_resources(
+    body: CreateCompletionRequest | CreateChatCompletionRequest,
+    llama_proxy: LlamaProxy,
+    body_model: str | None,
+    kwargs,
+) -> llama_cpp.Llama:
+    if llama_proxy is None:
+        raise HTTPException(
+            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+            detail="Service is not available",
+        )
+    llama = llama_proxy(body_model)
+    if body.logit_bias is not None:
+        kwargs["logit_bias"] = (
+            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
+            if body.logit_bias_type == "tokens"
+            else body.logit_bias
+        )
+
+    if body.grammar is not None:
+        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
+
+    if body.min_tokens > 0:
+        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
+            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
+        )
+        if "logits_processor" not in kwargs:
+            kwargs["logits_processor"] = _min_tokens_logits_processor
+        else:
+            kwargs["logits_processor"].extend(_min_tokens_logits_processor)
+    return llama
+
+
 async def get_event_publisher(
     request: Request,
     inner_send_chan: MemoryObjectSendStream[typing.Any],
-    iterator: Iterator[typing.Any],
-    on_complete: typing.Optional[typing.Callable[[], typing.Awaitable[None]]] = None,
+    body: CreateCompletionRequest | CreateChatCompletionRequest,
+    body_model: str | None,
+    llama_call,
+    kwargs,
 ):
     server_settings = next(get_server_settings())
     interrupt_requests = (
         server_settings.interrupt_requests if server_settings else False
     )
-    async with inner_send_chan:
-        try:
-            async for chunk in iterate_in_threadpool(iterator):
-                await inner_send_chan.send(dict(data=json.dumps(chunk)))
-                if await request.is_disconnected():
-                    raise anyio.get_cancelled_exc_class()()
-                if interrupt_requests and llama_outer_lock.locked():
-                    await inner_send_chan.send(dict(data="[DONE]"))
-                    raise anyio.get_cancelled_exc_class()()
-            await inner_send_chan.send(dict(data="[DONE]"))
-        except anyio.get_cancelled_exc_class() as e:
-            print("disconnected")
-            with anyio.move_on_after(1, shield=True):
-                print(f"Disconnected from client (via refresh/close) {request.client}")
-            raise e
-        finally:
-            if on_complete:
-                await on_complete()
+    async with contextlib.asynccontextmanager(get_llama_proxy)() as llama_proxy:
+        llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
+        async with inner_send_chan:
+            try:
+                iterator = await run_in_threadpool(llama_call, llama, **kwargs)
+                async for chunk in iterate_in_threadpool(iterator):
+                    await inner_send_chan.send(dict(data=json.dumps(chunk)))
+                    if await request.is_disconnected():
+                        raise anyio.get_cancelled_exc_class()()
+                    if interrupt_requests and llama_outer_lock.locked():
+                        await inner_send_chan.send(dict(data="[DONE]"))
+                        raise anyio.get_cancelled_exc_class()()
+                await inner_send_chan.send(dict(data="[DONE]"))
+            except anyio.get_cancelled_exc_class() as e:
+                print("disconnected")
+                with anyio.move_on_after(1, shield=True):
+                    print(
+                        f"Disconnected from client (via refresh/close) {request.client}"
+                    )
+                raise e
 
 
 def _logit_bias_tokens_to_input_ids(
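For context, a minimal runnable sketch (stub names, not from this repo) of the wrapping trick the new publisher relies on: `contextlib.asynccontextmanager` turns an async-generator dependency like `get_llama_proxy` into an async context manager, so the model proxy is acquired only while the stream is being produced and released when the `async with` block exits.

```python
import contextlib

import anyio


async def get_resource():
    # Stand-in for the real dependency that yields the shared model proxy.
    resource = {"model": "stub"}
    try:
        yield resource
    finally:
        # Cleanup runs when the enclosing async-with exits, i.e. after the
        # stream has finished or the client has disconnected.
        print("released", resource["model"])


async def publisher():
    # Same wrapping as get_event_publisher does with get_llama_proxy.
    async with contextlib.asynccontextmanager(get_resource)() as resource:
        print("using", resource["model"])


anyio.run(publisher)
```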
@@ -267,18 +304,11 @@ async def create_completion(
     request: Request,
     body: CreateCompletionRequest,
 ) -> llama_cpp.Completion:
-    exit_stack = contextlib.AsyncExitStack()
-    llama_proxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
-    if llama_proxy is None:
-        raise HTTPException(
-            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-            detail="Service is not available",
-        )
     if isinstance(body.prompt, list):
         assert len(body.prompt) <= 1
         body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""
 
-    llama = llama_proxy(
+    body_model = (
         body.model
         if request.url.path != "/v1/engines/copilot-codex/completions"
         else "copilot-codex"
@@ -293,60 +323,38 @@ async def create_completion(
     }
     kwargs = body.model_dump(exclude=exclude)
 
-    if body.logit_bias is not None:
-        kwargs["logit_bias"] = (
-            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
-            if body.logit_bias_type == "tokens"
-            else body.logit_bias
-        )
-
-    if body.grammar is not None:
-        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
-
-    if body.min_tokens > 0:
-        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
-            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
-        )
-        if "logits_processor" not in kwargs:
-            kwargs["logits_processor"] = _min_tokens_logits_processor
-        else:
-            kwargs["logits_processor"].extend(_min_tokens_logits_processor)
-
-    try:
-        iterator_or_completion: Union[
-            llama_cpp.CreateCompletionResponse,
-            Iterator[llama_cpp.CreateCompletionStreamResponse],
-        ] = await run_in_threadpool(llama, **kwargs)
-    except Exception as err:
-        await exit_stack.aclose()
-        raise err
-
-    if isinstance(iterator_or_completion, Iterator):
-        # EAFP: It's easier to ask for forgiveness than permission
-        first_response = await run_in_threadpool(next, iterator_or_completion)
-
-        # If no exception was raised from first_response, we can assume that
-        # the iterator is valid and we can use it to stream the response.
-        def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
-            yield first_response
-            yield from iterator_or_completion
-
+    # handle streaming request
+    if kwargs.get("stream", False):
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
             recv_chan,
             data_sender_callable=partial(  # type: ignore
                 get_event_publisher,
                 request=request,
                 inner_send_chan=send_chan,
-                iterator=iterator(),
-                on_complete=exit_stack.aclose,
+                body=body,
+                body_model=body_model,
+                llama_call=llama_cpp.Llama.__call__,
+                kwargs=kwargs,
             ),
             sep="\n",
             ping_message_factory=_ping_message_factory,
         )
-    else:
-        await exit_stack.aclose()
-        return iterator_or_completion
+
+    # handle regular request
+    async with contextlib.asynccontextmanager(get_llama_proxy)() as llama_proxy:
+        llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
+
+        if await request.is_disconnected():
+            print(
+                f"Disconnected from client (via refresh/close) before llm invoked {request.client}"
+            )
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail="Client closed request",
+            )
+
+        return await run_in_threadpool(llama, **kwargs)
 
 
 @router.post(
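A runnable sketch of the producer/consumer shape behind the streaming branch above: a bounded anyio memory object stream decouples chunk production from the SSE consumer, with "[DONE]" as the terminator. Only the anyio and json calls are real; the chunk payloads are invented for illustration.

```python
import json

import anyio


async def produce(send_chan, chunks):
    # Mirrors get_event_publisher: send each chunk as SSE data, then [DONE],
    # and close the stream by exiting the async-with block.
    async with send_chan:
        for chunk in chunks:
            await send_chan.send(dict(data=json.dumps(chunk)))
        await send_chan.send(dict(data="[DONE]"))


async def consume(recv_chan):
    # Stands in for EventSourceResponse draining recv_chan.
    async with recv_chan:
        async for event in recv_chan:
            print(event["data"])


async def main():
    send_chan, recv_chan = anyio.create_memory_object_stream(10)
    async with anyio.create_task_group() as tg:
        tg.start_soon(produce, send_chan, [{"text": "Hello"}, {"text": " world"}])
        tg.start_soon(consume, recv_chan)


anyio.run(main)
```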
@@ -474,74 +482,48 @@ async def create_chat_completion(
     # where the dependency is cleaned up before a StreamingResponse
     # is complete.
     # https://github.com/tiangolo/fastapi/issues/11143
-    exit_stack = contextlib.AsyncExitStack()
-    llama_proxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
-    if llama_proxy is None:
-        raise HTTPException(
-            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-            detail="Service is not available",
-        )
+
+    body_model = body.model
     exclude = {
         "n",
         "logit_bias_type",
         "user",
         "min_tokens",
     }
     kwargs = body.model_dump(exclude=exclude)
-    llama = llama_proxy(body.model)
-    if body.logit_bias is not None:
-        kwargs["logit_bias"] = (
-            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
-            if body.logit_bias_type == "tokens"
-            else body.logit_bias
-        )
-
-    if body.grammar is not None:
-        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
-
-    if body.min_tokens > 0:
-        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
-            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
-        )
-        if "logits_processor" not in kwargs:
-            kwargs["logits_processor"] = _min_tokens_logits_processor
-        else:
-            kwargs["logits_processor"].extend(_min_tokens_logits_processor)
-
-    try:
-        iterator_or_completion: Union[
-            llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk]
-        ] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
-    except Exception as err:
-        await exit_stack.aclose()
-        raise err
-
-    if isinstance(iterator_or_completion, Iterator):
-        # EAFP: It's easier to ask for forgiveness than permission
-        first_response = await run_in_threadpool(next, iterator_or_completion)
-
-        # If no exception was raised from first_response, we can assume that
-        # the iterator is valid and we can use it to stream the response.
-        def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
-            yield first_response
-            yield from iterator_or_completion
 
+    # handle streaming request
+    if kwargs.get("stream", False):
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
             recv_chan,
             data_sender_callable=partial(  # type: ignore
                 get_event_publisher,
                 request=request,
                 inner_send_chan=send_chan,
-                iterator=iterator(),
-                on_complete=exit_stack.aclose,
+                body=body,
+                body_model=body_model,
+                llama_call=llama_cpp.Llama.create_chat_completion,
+                kwargs=kwargs,
             ),
             sep="\n",
             ping_message_factory=_ping_message_factory,
         )
-    else:
-        await exit_stack.aclose()
-        return iterator_or_completion
+
+    # handle regular request
+    async with contextlib.asynccontextmanager(get_llama_proxy)() as llama_proxy:
+        llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
+
+        if await request.is_disconnected():
+            print(
+                f"Disconnected from client (via refresh/close) before llm invoked {request.client}"
+            )
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail="Client closed request",
+            )
+
+        return await run_in_threadpool(llama.create_chat_completion, **kwargs)
 
 
 @router.get(
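A hypothetical client-side check of the two code paths, assuming the server is running locally on port 8000 and exposes the OpenAI-style /v1/completions route; httpx is just one convenient HTTP client choice, not something this diff requires.

```python
import json

import httpx

payload = {"prompt": "Hello", "max_tokens": 16}

# Regular request: the handler resolves the model inside its own
# async-with block and returns a single JSON completion.
resp = httpx.post("http://localhost:8000/v1/completions", json=payload, timeout=60)
print(resp.json()["choices"][0]["text"])

# Streaming request: the handler returns an EventSourceResponse, so the
# body arrives as server-sent events terminated by "[DONE]".
with httpx.stream(
    "POST",
    "http://localhost:8000/v1/completions",
    json={**payload, "stream": True},
    timeout=60,
) as resp:
    for line in resp.iter_lines():
        if line.startswith("data: ") and line != "data: [DONE]":
            chunk = json.loads(line[len("data: "):])
            print(chunk["choices"][0]["text"], end="")
```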