are accumulated in a single <details> block and then returned together with the final answer.
"""

- from typing import AsyncGenerator
- import time
import json
import os
+ import asyncio

from fastapi.responses import JSONResponse, StreamingResponse
+ from fastapi import FastAPI, HTTPException, APIRouter
from fastapi.middleware.cors import CORSMiddleware
- from fastapi import FastAPI, HTTPException
from dotenv import load_dotenv
from loguru import logger
+ import uvicorn
import httpx

- from utils.classes import ChatCompletionRequest
+ from utils.classes import (
+     ChatCompletionRequest,
+     ChatCompletionResponse,
+     CONTACT_US_MAP,
+     MessageModel,
+     ChoiceModel,
+ )
from utils.llm.pipeline import Pipeline

load_dotenv()

logger.info(f"Using OpenAI API Base URL: {OPENAI_API_BASE_URL}")

-
- # ----------------------------------------------------------------------
- # Event Aggregator: For final message assembly
- # ----------------------------------------------------------------------
- class EventAggregator:
-     def __init__(self):
-         self.buffer = ""
-
-     async def __call__(self, event: dict):
-         if event.get("type") == "replace":
-             self.buffer = event.get("data", {}).get("content", "")
-         else:
-             self.buffer += event.get("data", {}).get("content", "")
-
-     def get_buffer(self) -> str:
-         return self.buffer
-
-
# ----------------------------------------------------------------------
# FastAPI App and Endpoints
# ----------------------------------------------------------------------
app = FastAPI(
    title="OpenAI Compatible API with MCTS",
    description="Wraps LLM invocations with Monte Carlo Tree Search refinement",
-     version="0.0.1",
+     version="0.0.91",
+     root_path="/v1",
+     contact=CONTACT_US_MAP,
)
-
+ # CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
+ # Defining routers
+ model_router = APIRouter(prefix="/models", tags=["Model Management"])
+ chat_router = APIRouter(prefix="/chat", tags=["Chat Completions"])
+
pipeline = Pipeline(
    openai_api_base_url=OPENAI_API_BASE_URL, openai_api_key=OPENAI_API_KEY
)


- @app.post("/v1/chat/completions")
+ # Helper function to generate streaming responses.
+ async def streaming_event_generator(
+     event_queue: asyncio.Queue, stream_task: asyncio.Task
+ ):
+     # Emit the opening <think> block
+     opening_event = {"choices": [{"delta": {"content": "<think>\n"}}]}
+     yield f"data: {json.dumps(opening_event)}\n\n"
+     thinking_closed = False
+
+     while True:
+         try:
+             event = await asyncio.wait_for(event_queue.get(), timeout=30)
+         except asyncio.TimeoutError:
+             break
+
+         if event.get("type") in ["message", "replace"]:
+             if event.get("final"):
+                 if not thinking_closed:
+                     closing_event = {"choices": [{"delta": {"content": "\n</think>"}}]}
+                     yield f"data: {json.dumps(closing_event)}\n\n"
+                     thinking_closed = True
+                 # Send the final answer separately.
+                 chunk = {
+                     "choices": [
+                         {
+                             "delta": {
+                                 "content": event["data"].get("reasoning_content", "")
+                             }
+                         }
+                     ]
+                 }
+                 yield f"data: {json.dumps(chunk)}\n\n"
+             else:
+                 # For intermediate tokens, strip accidental <think> markers.
+                 token = event["data"].get("reasoning_content", "")
+                 token = token.replace("<think>\n", "").replace("\n</think>", "")
+                 chunk = {"choices": [{"delta": {"content": token}}]}
+                 yield f"data: {json.dumps(chunk)}\n\n"
+
+         if event.get("done"):
+             break
+
+     yield "data: [DONE]\n\n"
+     await stream_task
+
+
+ # Helper function to accumulate tokens for non-streaming response.
+ async def accumulate_tokens(
+     event_queue: asyncio.Queue, stream_task: asyncio.Task
+ ) -> str:
+     collected = ""
+     in_block = False
+
+     while True:
+         try:
+             event = await asyncio.wait_for(event_queue.get(), timeout=30)
+         except asyncio.TimeoutError:
+             break
+
+         if event.get("type") in ["message", "replace"]:
+             token = event["data"].get("reasoning_content", "")
+             # Start a <think> block only once.
+             if not in_block:
+                 collected += "<think>\n"
+                 in_block = True
+             collected += token
+             if event.get("block_end", False):
+                 collected += "\n</think>"
+                 in_block = False
+         if event.get("done"):
+             if in_block:
+                 collected += "\n</think>"
+                 in_block = False
+             break
+
+     await stream_task
+     collected = collected.rstrip()
+     if collected.endswith("</think>"):
+         collected = collected[: -len("</think>")].rstrip()
+     return collected
+
+
+ @chat_router.post("/completions", response_model=ChatCompletionResponse)
async def chat_completions(request: ChatCompletionRequest):
+     """
+     Handles chat completion requests by processing input through a pipeline and
+     returning the generated response. Supports both streaming and non-streaming
+     modes based on the request. Refer to the ChatCompletionRequest and
+     ReasoningEffort schemas for more information.
+
+     ## Args:
+     - `request` (`ChatCompletionRequest`): The input request containing model
+       details and streaming preference.
+
+     ## Returns:
+     - `dict` or `StreamingResponse`: A JSON response with the generated chat
+       completion, either as a single response or streamed chunks.
+     """
+     # To collect streamed events.
+     event_queue = asyncio.Queue()
+
+     # Emitter: push events (dictionaries) into the queue.
+     async def emitter(event: dict):
+         await event_queue.put(event)
+
+     # Launch the streaming pipeline task.
+     stream_task = asyncio.create_task(pipeline.run_stream(request, emitter))
+
    if request.stream:
-         aggregator = EventAggregator()
-         final_text = await pipeline.run(request, aggregator)
-         full_message = aggregator.get_buffer() + "\n" + final_text
-         final_response = {
-             "id": "mcts_response",
-             "object": "chat.completion",
-             "created": time.time(),
-             "model": request.model,
-             "choices": [{"message": {"role": "assistant", "content": full_message}}],
-         }
-
-         # Return a single JSON chunk with mimetype application/json
-         async def single_chunk() -> AsyncGenerator[str, None]:
-             yield json.dumps(final_response)
-
-         return StreamingResponse(single_chunk(), media_type="application/json")
+         return StreamingResponse(
+             streaming_event_generator(event_queue, stream_task),
+             media_type="text/event-stream",
+         )
    else:
-         aggregator = EventAggregator()
-         final_text = await pipeline.run(request, aggregator)
-         full_message = aggregator.get_buffer() + "\n" + final_text
-         return {
-             "id": "mcts_response",
-             "object": "chat.completion",
-             "created": time.time(),
-             "model": request.model,
-             "choices": [{"message": {"role": "assistant", "content": full_message}}],
-         }
-
-
- @app.get("/v1/models")
+         collected = await accumulate_tokens(event_queue, stream_task)
+         chat_response = ChatCompletionResponse(
+             model=request.model,
+             choices=[
+                 ChoiceModel(
+                     message=MessageModel(
+                         reasoning_content=collected,
+                         content=collected,
+                     )
+                 )
+             ],
+         )
+         return JSONResponse(content=chat_response.model_dump())
+
+
+ @model_router.get("", response_description="Proxied JSON Response")
async def list_models():
+     """
+     Asynchronously fetches the list of models from the OpenAI API.
+     Sends a `GET` request to the models endpoint and returns the JSON via a proxy.
+     """
    url = f"{OPENAI_API_BASE_URL}/models"
    async with httpx.AsyncClient() as client:
        resp = await client.get(
@@ -126,7 +220,8 @@ async def list_models():
    return JSONResponse(content=data)


- if __name__ == "__main__":
-     import uvicorn
+ app.include_router(model_router)
+ app.include_router(chat_router)

+ if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)
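
For reference, a minimal client sketch for consuming the new streaming endpoint (illustrative only, not part of this diff). It assumes the service runs locally on port 8000, that the chat route resolves to /chat/completions (externally /v1/chat/completions when a proxy honours root_path="/v1"), and that ChatCompletionRequest accepts the usual OpenAI-style model, messages, and stream fields; the model name below is a placeholder.

import asyncio
import json

import httpx


async def stream_chat(prompt: str) -> None:
    payload = {
        "model": "your-model-name",  # placeholder; use any model served upstream
        "messages": [{"role": "user", "content": prompt}],
        "stream": True,
    }
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST", "http://localhost:8000/chat/completions", json=payload
        ) as response:
            async for line in response.aiter_lines():
                # Events arrive as "data: {json}" lines, terminated by "data: [DONE]".
                if not line.startswith("data: "):
                    continue
                data = line[len("data: "):]
                if data == "[DONE]":
                    break
                chunk = json.loads(data)
                print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)


if __name__ == "__main__":
    asyncio.run(stream_chat("Hello there"))

Setting "stream": False in the same payload returns a single ChatCompletionResponse JSON body instead, with the accumulated reasoning placed in both reasoning_content and content.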