 import json
 import multiprocessing
 from threading import Lock
-from typing import List, Optional, Union, Iterator, Dict
+from functools import partial
+from typing import Iterator, List, Optional, Union, Dict
 from typing_extensions import TypedDict, Literal
 
 import llama_cpp
 
-from fastapi import Depends, FastAPI, APIRouter
+import anyio
+from anyio.streams.memory import MemoryObjectSendStream
+from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
+from fastapi import Depends, FastAPI, APIRouter, Request
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
 from sse_starlette.sse import EventSourceResponse
@@ -241,35 +245,49 @@ class Config:
     "/v1/completions",
     response_model=CreateCompletionResponse,
 )
-def create_completion(
-    request: CreateCompletionRequest, llama: llama_cpp.Llama = Depends(get_llama)
+async def create_completion(
+    request: Request,
+    body: CreateCompletionRequest,
+    llama: llama_cpp.Llama = Depends(get_llama),
 ):
-    if isinstance(request.prompt, list):
-        assert len(request.prompt) <= 1
-        request.prompt = request.prompt[0] if len(request.prompt) > 0 else ""
-
-    completion_or_chunks = llama(
-        **request.dict(
-            exclude={
-                "n",
-                "best_of",
-                "logit_bias",
-                "user",
-            }
-        )
-    )
-    if request.stream:
-
-        async def server_sent_events(
-            chunks: Iterator[llama_cpp.CompletionChunk],
-        ):
-            for chunk in chunks:
-                yield dict(data=json.dumps(chunk))
+    if isinstance(body.prompt, list):
+        assert len(body.prompt) <= 1
+        body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""
+
+    exclude = {
+        "n",
+        "best_of",
+        "logit_bias",
+        "user",
+    }
+    kwargs = body.dict(exclude=exclude)
+    if body.stream:
+        send_chan, recv_chan = anyio.create_memory_object_stream(10)
+
+        async def event_publisher(inner_send_chan: MemoryObjectSendStream):
+            async with inner_send_chan:
+                try:
+                    iterator: Iterator[llama_cpp.CompletionChunk] = await run_in_threadpool(llama, **kwargs)  # type: ignore
+                    async for chunk in iterate_in_threadpool(iterator):
+                        await inner_send_chan.send(dict(data=json.dumps(chunk)))
+                        if await request.is_disconnected():
+                            raise anyio.get_cancelled_exc_class()()
+                    await inner_send_chan.send(dict(data="[DONE]"))
+                except anyio.get_cancelled_exc_class() as e:
+                    print("disconnected")
+                    with anyio.move_on_after(1, shield=True):
+                        print(
+                            f"Disconnected from client (via refresh/close) {request.client}"
+                        )
+                        await inner_send_chan.send(dict(closing=True))
+                        raise e
 
-        chunks: Iterator[llama_cpp.CompletionChunk] = completion_or_chunks  # type: ignore
-        return EventSourceResponse(server_sent_events(chunks))
-    completion: llama_cpp.Completion = completion_or_chunks  # type: ignore
-    return completion
+        return EventSourceResponse(
+            recv_chan, data_sender_callable=partial(event_publisher, send_chan)
+        )
+    else:
+        completion: llama_cpp.Completion = await run_in_threadpool(llama, **kwargs)  # type: ignore
+        return completion
 
 
 class CreateEmbeddingRequest(BaseModel):
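The new create_completion handler streams chunks over server-sent events whenever body.stream is set. Below is a minimal client-side sketch for consuming that stream; it assumes the server is running locally on port 8000, and the prompt, max_tokens, and chunk field accesses are illustrative rather than taken from this diff.

# Hypothetical client for the streaming /v1/completions endpoint (not part of this diff).
import json
import requests  # assumed to be installed separately

response = requests.post(
    "http://localhost:8000/v1/completions",  # assumed host/port
    json={"prompt": "Hello, my name is", "max_tokens": 32, "stream": True},
    stream=True,
)
for line in response.iter_lines():
    if not line:
        continue  # skip blank separator lines between SSE events
    decoded = line.decode("utf-8")
    if decoded.startswith("data: "):
        payload = decoded[len("data: "):]
        if payload == "[DONE]":
            break  # sentinel the server sends when the stream ends
        chunk = json.loads(payload)
        print(chunk["choices"][0]["text"], end="", flush=True)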
@@ -292,10 +310,12 @@ class Config:
     "/v1/embeddings",
     response_model=CreateEmbeddingResponse,
 )
-def create_embedding(
+async def create_embedding(
     request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
 ):
-    return llama.create_embedding(**request.dict(exclude={"user"}))
+    return await run_in_threadpool(
+        llama.create_embedding, **request.dict(exclude={"user"})
+    )
 
 
 class ChatCompletionRequestMessage(BaseModel):
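The embeddings handler now wraps the blocking llama.create_embedding call in Starlette's run_in_threadpool so it no longer stalls the event loop. A minimal, self-contained sketch of that pattern, using a stand-in slow_add function rather than anything from this diff:

import time

import anyio
from starlette.concurrency import run_in_threadpool


def slow_add(a: int, b: int) -> int:
    # Stand-in for a blocking call such as llama.create_embedding.
    time.sleep(1)
    return a + b


async def main() -> None:
    # The blocking function runs on a worker thread; the event loop stays responsive.
    result = await run_in_threadpool(slow_add, 2, 3)
    print(result)  # prints 5 after roughly one second


anyio.run(main)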
@@ -349,36 +369,47 @@ class Config:
     "/v1/chat/completions",
     response_model=CreateChatCompletionResponse,
 )
-def create_chat_completion(
-    request: CreateChatCompletionRequest,
+async def create_chat_completion(
+    request: Request,
+    body: CreateChatCompletionRequest,
     llama: llama_cpp.Llama = Depends(get_llama),
 ) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
-    completion_or_chunks = llama.create_chat_completion(
-        **request.dict(
-            exclude={
-                "n",
-                "logit_bias",
-                "user",
-            }
-        ),
-    )
-
-    if request.stream:
-
-        async def server_sent_events(
-            chat_chunks: Iterator[llama_cpp.ChatCompletionChunk],
-        ):
-            for chat_chunk in chat_chunks:
-                yield dict(data=json.dumps(chat_chunk))
-            yield dict(data="[DONE]")
-
-        chunks: Iterator[llama_cpp.ChatCompletionChunk] = completion_or_chunks  # type: ignore
+    exclude = {
+        "n",
+        "logit_bias",
+        "user",
+    }
+    kwargs = body.dict(exclude=exclude)
+    if body.stream:
+        send_chan, recv_chan = anyio.create_memory_object_stream(10)
+
+        async def event_publisher(inner_send_chan: MemoryObjectSendStream):
+            async with inner_send_chan:
+                try:
+                    iterator: Iterator[llama_cpp.ChatCompletionChunk] = await run_in_threadpool(llama.create_chat_completion, **kwargs)  # type: ignore
+                    async for chat_chunk in iterate_in_threadpool(iterator):
+                        await inner_send_chan.send(dict(data=json.dumps(chat_chunk)))
+                        if await request.is_disconnected():
+                            raise anyio.get_cancelled_exc_class()()
+                    await inner_send_chan.send(dict(data="[DONE]"))
+                except anyio.get_cancelled_exc_class() as e:
+                    print("disconnected")
+                    with anyio.move_on_after(1, shield=True):
+                        print(
+                            f"Disconnected from client (via refresh/close) {request.client}"
+                        )
+                        await inner_send_chan.send(dict(closing=True))
+                        raise e
 
         return EventSourceResponse(
-            server_sent_events(chunks),
+            recv_chan,
+            data_sender_callable=partial(event_publisher, send_chan),
+        )
+    else:
+        completion: llama_cpp.ChatCompletion = await run_in_threadpool(
+            llama.create_chat_completion, **kwargs  # type: ignore
         )
-    completion: llama_cpp.ChatCompletion = completion_or_chunks  # type: ignore
-    return completion
+        return completion
 
 
 class ModelData(TypedDict):
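Both streaming handlers now rely on the same mechanism: a bounded anyio memory object stream connects an event_publisher coroutine (the producer) to the EventSourceResponse (the consumer) via data_sender_callable, which provides backpressure and a place to notice client disconnects. The self-contained sketch below shows only the stream mechanics; the producer and consumer names and the fake chunks are illustrative and not part of this diff.

import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream


async def producer(send_chan: MemoryObjectSendStream) -> None:
    # Plays the role of event_publisher: pushes events, then closes the stream.
    async with send_chan:
        for i in range(3):
            await send_chan.send(dict(data=f"chunk {i}"))
        await send_chan.send(dict(data="[DONE]"))


async def consumer(recv_chan: MemoryObjectReceiveStream) -> None:
    # Plays the role of EventSourceResponse: drains events until the stream closes.
    async with recv_chan:
        async for event in recv_chan:
            print(event)


async def main() -> None:
    # Buffer size 10 mirrors anyio.create_memory_object_stream(10) above; the
    # producer blocks once 10 unconsumed items are queued (backpressure).
    send_chan, recv_chan = anyio.create_memory_object_stream(10)
    async with anyio.create_task_group() as tg:
        tg.start_soon(producer, send_chan)
        tg.start_soon(consumer, recv_chan)


anyio.run(main)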
@@ -397,7 +428,7 @@ class ModelList(TypedDict):
 
 
 @router.get("/v1/models", response_model=GetModelResponse)
-def get_models(
+async def get_models(
     settings: Settings = Depends(get_settings),
     llama: llama_cpp.Llama = Depends(get_llama),
 ) -> ModelList: