Commit ec5991c
Support for Streaming and System Prompt (#2)
# Release v0.0.9

* Push thinking tokens to reasoning_content
* Support system prompt + Improve API docs
* Fake streaming adds <think> tags for Chat Interfaces
* Implement real streaming support in chat completions
* Clean Codebase
* Update API docs in chat completions
1 parent b2e421d commit ec5991c
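The release notes mention system-prompt support and thinking tokens surfaced as `reasoning_content`. A hypothetical client snippet (not part of this commit) illustrating both, assuming the `app.py` server is running locally on port 8000 and using a placeholder model name:

```python
import httpx

# Hypothetical request; the model name is a placeholder and the server is
# assumed to be the app.py service on localhost:8000.
payload = {
    "model": "my-model",
    "messages": [
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "Why is the sky blue?"},
    ],
    "max_tokens": 1024,
    "reasoning_effort": "low",
    "stream": False,
}

resp = httpx.post(
    "http://localhost:8000/v1/chat/completions", json=payload, timeout=None
)
message = resp.json()["choices"][0]["message"]
# In this release both fields carry the aggregated output, with the
# intermediate reasoning wrapped in <think> ... </think> tags.
print(message["reasoning_content"])
print(message["content"])
```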

File tree: 8 files changed (+401, -106 lines)


README.md

Lines changed: 3 additions & 2 deletions

````diff
@@ -108,7 +108,8 @@ curl -X 'POST' \
       }
     ],
     "max_tokens": 1024,
-    "temperature": 0.5
+    "temperature": 0.5,
+    "reasoning_effort": "low"
   }' | jq -r '.choices[0].message.content'
 ```
 
@@ -130,7 +131,7 @@ Wraps a chat completion request in an MCTS pipeline that refines the answer by g
 | temperature | number (optional) | `0.7` | Controls the randomness of the output. |
 | stream | boolean (optional) | `false` | If false, aggregates streamed responses and returns on completion. If true, streams intermediate responses. |
 | reasoning_effort | string (optional) | `normal` | Controls the `MCTSAgent` search settings: |
-| => | => | => | **`normal`** - 2 iterations, 2 simulations per iteration, and 2 child nodes per parent (default). |
+| => | => | => | **`low`** - 2 iterations, 2 simulations per iteration, and 2 child nodes per parent (default). |
 | => | => | => | `medium` - 3 iterations, 3 simulations per iteration, and 3 child nodes per parent. |
 | => | => | => | `high` - 4 iterations, 4 simulations per iteration, and 4 child nodes per parent. |
````
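The README table above fixes the search budget for each `reasoning_effort` level. A minimal, purely illustrative sketch of how that mapping could be expressed in code; the `MCTSSettings` dataclass and `EFFORT_TO_SETTINGS` dict are hypothetical names, not the repository's actual `MCTSAgent` or `ReasoningEffort` implementation:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MCTSSettings:
    """Search budget for one reasoning_effort level (hypothetical name)."""
    iterations: int
    simulations_per_iteration: int
    children_per_parent: int


# Values taken from the README table; "low" is described as the default level.
EFFORT_TO_SETTINGS = {
    "low": MCTSSettings(2, 2, 2),
    "medium": MCTSSettings(3, 3, 3),
    "high": MCTSSettings(4, 4, 4),
}


def settings_for(effort: str = "low") -> MCTSSettings:
    # Fall back to the default budget for unknown effort strings.
    return EFFORT_TO_SETTINGS.get(effort, EFFORT_TO_SETTINGS["low"])
```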

app.py

Lines changed: 151 additions & 56 deletions

````diff
@@ -11,19 +11,25 @@
 are accumulated in a single <details> block and then returned together with the final answer.
 """
 
-from typing import AsyncGenerator
-import time
 import json
 import os
+import asyncio
 
 from fastapi.responses import JSONResponse, StreamingResponse
+from fastapi import FastAPI, HTTPException, APIRouter
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi import FastAPI, HTTPException
 from dotenv import load_dotenv
 from loguru import logger
+import uvicorn
 import httpx
 
-from utils.classes import ChatCompletionRequest
+from utils.classes import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    CONTACT_US_MAP,
+    MessageModel,
+    ChoiceModel,
+)
 from utils.llm.pipeline import Pipeline
 
 load_dotenv()
@@ -40,79 +46,167 @@
 
 logger.info(f"Using OpenAI API Base URL: {OPENAI_API_BASE_URL}")
 
-
-# ----------------------------------------------------------------------
-# Event Aggregator: For final message assembly
-# ----------------------------------------------------------------------
-class EventAggregator:
-    def __init__(self):
-        self.buffer = ""
-
-    async def __call__(self, event: dict):
-        if event.get("type") == "replace":
-            self.buffer = event.get("data", {}).get("content", "")
-        else:
-            self.buffer += event.get("data", {}).get("content", "")
-
-    def get_buffer(self) -> str:
-        return self.buffer
-
-
 # ----------------------------------------------------------------------
 # FastAPI App and Endpoints
 # ----------------------------------------------------------------------
 app = FastAPI(
     title="OpenAI Compatible API with MCTS",
     description="Wraps LLM invocations with Monte Carlo Tree Search refinement",
-    version="0.0.1",
+    version="0.0.91",
+    root_path="/v1",
+    contact=CONTACT_US_MAP,
 )
-
+# CORS middleware
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
     allow_methods=["*"],
     allow_headers=["*"],
     allow_credentials=True,
 )
+# Defining routers
+model_router = APIRouter(prefix="/models", tags=["Model Management"])
+chat_router = APIRouter(prefix="/chat", tags=["Chat Completions"])
+
 pipeline = Pipeline(
     openai_api_base_url=OPENAI_API_BASE_URL, openai_api_key=OPENAI_API_KEY
 )
 
 
-@app.post("/v1/chat/completions")
+# Helper function to generate streaming responses.
+async def streaming_event_generator(
+    event_queue: asyncio.Queue, stream_task: asyncio.Task
+):
+    # Emit the opening <think> block
+    opening_event = {"choices": [{"delta": {"content": "<think>\n"}}]}
+    yield f"data: {json.dumps(opening_event)}\n\n"
+    thinking_closed = False
+
+    while True:
+        try:
+            event = await asyncio.wait_for(event_queue.get(), timeout=30)
+        except asyncio.TimeoutError:
+            break
+
+        if event.get("type") in ["message", "replace"]:
+            if event.get("final"):
+                if not thinking_closed:
+                    closing_event = {"choices": [{"delta": {"content": "\n</think>"}}]}
+                    yield f"data: {json.dumps(closing_event)}\n\n"
+                    thinking_closed = True
+                # Send the final answer separately.
+                chunk = {
+                    "choices": [
+                        {
+                            "delta": {
+                                "content": event["data"].get("reasoning_content", "")
+                            }
+                        }
+                    ]
+                }
+                yield f"data: {json.dumps(chunk)}\n\n"
+            else:
+                # For intermediate tokens, strip accidental <think> markers.
+                token = event["data"].get("reasoning_content", "")
+                token = token.replace("<think>\n", "").replace("\n</think>", "")
+                chunk = {"choices": [{"delta": {"content": token}}]}
+                yield f"data: {json.dumps(chunk)}\n\n"
+
+        if event.get("done"):
+            break
+
+    yield "data: [DONE]\n\n"
+    await stream_task
+
+
+# Helper function to accumulate tokens for non-streaming response.
+async def accumulate_tokens(
+    event_queue: asyncio.Queue, stream_task: asyncio.Task
+) -> str:
+    collected = ""
+    in_block = False
+
+    while True:
+        try:
+            event = await asyncio.wait_for(event_queue.get(), timeout=30)
+        except asyncio.TimeoutError:
+            break
+
+        if event.get("type") in ["message", "replace"]:
+            token = event["data"].get("reasoning_content", "")
+            # Start a <think> block only once.
+            if not in_block:
+                collected += "<think>\n"
+                in_block = True
+            collected += token
+            if event.get("block_end", False):
+                collected += "\n</think>"
+                in_block = False
+        if event.get("done"):
+            if in_block:
+                collected += "\n</think>"
+                in_block = False
+            break
+
+    await stream_task
+    collected = collected.rstrip()
+    if collected.endswith("</think>"):
+        collected = collected[: -len("</think>")].rstrip()
+    return collected
+
+
+@chat_router.post("/completions", response_model=ChatCompletionResponse)
 async def chat_completions(request: ChatCompletionRequest):
+    """
+    Handles chat completion requests by processing input through a pipeline and
+    returning the generated response. Supports both streaming and non-streaming
+    modes based on the request. Refer to the ChatCompletionRequest and
+    ReasoningEffort schemas for more information.
+
+    ## Args:
+    - `request` (`ChatCompletionRequest`): The input request containing model
+        details and streaming preference.
+
+    ## Returns:
+    - `dict` or `StreamingResponse`: A JSON response with the generated chat
+        completion, either as a single response or streamed chunks.
+    """ # To collect streamed events.
+    event_queue = asyncio.Queue()
+
+    # Emitter: push events (dictionaries) into the queue.
+    async def emitter(event: dict):
+        await event_queue.put(event)
+
+    # Launch the streaming pipeline task.
+    stream_task = asyncio.create_task(pipeline.run_stream(request, emitter))
+
     if request.stream:
-        aggregator = EventAggregator()
-        final_text = await pipeline.run(request, aggregator)
-        full_message = aggregator.get_buffer() + "\n" + final_text
-        final_response = {
-            "id": "mcts_response",
-            "object": "chat.completion",
-            "created": time.time(),
-            "model": request.model,
-            "choices": [{"message": {"role": "assistant", "content": full_message}}],
-        }
-
-        # Return a single JSON chunk with mimetype application/json
-        async def single_chunk() -> AsyncGenerator[str, None]:
-            yield json.dumps(final_response)
-
-        return StreamingResponse(single_chunk(), media_type="application/json")
+        return StreamingResponse(
+            streaming_event_generator(event_queue, stream_task),
+            media_type="text/event-stream",
+        )
     else:
-        aggregator = EventAggregator()
-        final_text = await pipeline.run(request, aggregator)
-        full_message = aggregator.get_buffer() + "\n" + final_text
-        return {
-            "id": "mcts_response",
-            "object": "chat.completion",
-            "created": time.time(),
-            "model": request.model,
-            "choices": [{"message": {"role": "assistant", "content": full_message}}],
-        }
-
-
-@app.get("/v1/models")
+        collected = await accumulate_tokens(event_queue, stream_task)
+        chat_response = ChatCompletionResponse(
+            model=request.model,
+            choices=[
+                ChoiceModel(
+                    message=MessageModel(
+                        reasoning_content=collected,
+                        content=collected,
+                    )
+                )
+            ],
+        )
+        return JSONResponse(content=chat_response.model_dump())
+
+
+@model_router.get("", response_description="Proxied JSON Response")
 async def list_models():
+    """
+    Asynchronously fetches the list of models from the OpenAI API.
+    Sends a `GET` request to the models endpoint and returns the JSON via a proxy.
+    """
     url = f"{OPENAI_API_BASE_URL}/models"
     async with httpx.AsyncClient() as client:
         resp = await client.get(
@@ -126,7 +220,8 @@ async def list_models():
         return JSONResponse(content=data)
 
 
-if __name__ == "__main__":
-    import uvicorn
+app.include_router(model_router)
+app.include_router(chat_router)
 
+if __name__ == "__main__":
     uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)
````
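The new streaming path returns a `text/event-stream` response of `data:` chunks, with the reasoning portion wrapped in `<think>` tags and the stream terminated by `data: [DONE]`, as implemented in `streaming_event_generator` above. A minimal client sketch, not part of the commit, that consumes this stream; the endpoint URL and model name are assumptions:

```python
import asyncio
import json

import httpx


async def stream_completion() -> None:
    # Hypothetical payload; the model name is a placeholder.
    payload = {
        "model": "my-model",
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True,
        "reasoning_effort": "low",
    }
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST", "http://localhost:8000/v1/chat/completions", json=payload
        ) as resp:
            async for line in resp.aiter_lines():
                # SSE frames look like "data: {...}" with a blank line between them.
                if not line.startswith("data: "):
                    continue
                data = line[len("data: "):]
                if data.strip() == "[DONE]":
                    break
                chunk = json.loads(data)
                # Each chunk carries the next token in choices[0].delta.content,
                # with the reasoning portion wrapped in <think> ... </think>.
                print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)


if __name__ == "__main__":
    asyncio.run(stream_completion())
```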

pyproject.toml

Lines changed: 1 addition & 1 deletion

````diff
@@ -1,6 +1,6 @@
 [project]
 name = "mcts-openai-api"
-version = "0.1.0"
+version = "0.0.91"
 description = "Every incoming request is wrapped with a Monte Carlo Tree Search (MCTS) pipeline"
 authors = [
     {name = "Krishnakanth Alagiri",email = "39209037+bearlike@users.noreply.github.com"}
````

repo-to-prompt.codemod.js

Lines changed: 14 additions & 2 deletions

````diff
@@ -1,22 +1,34 @@
-
 /**
  * @param {vscode} vscode the entry to vscode plugin api
  * @param {vscode.Uri} selectedFile currently selected file in vscode explorer
  * @param {vscode.Uri[]} selectedFiles currently multi-selected files in vscode explorer
  */
 async function run(vscode, selectedFile, selectedFiles) {
     console.log('You can debug the script with console.log')
+
+    // Ask user for repository name
+    const repoName = "mcts-openai-api";
+
+    if (!repoName) {
+        vscode.window.showErrorMessage('Repository name is required');
+        return;
+    }
+
     // remove useless file from selectedFiles
     selectedFiles = selectedFiles.filter(file => !file.path.endsWith('.env') && !file.path.endsWith('.lock') && !file.path.endsWith('LICENSE'));
     const lines = [];
+    lines.push('\n<details>\n')
     for (const file of selectedFiles) {
-        lines.push('<file path="' + file.path + '">')
+        // Use regex to remove everything before the repo name
+        const projectPath = file.path.replace(new RegExp(`^.*?(${repoName}/.*)$`), "$1");
+        lines.push('<file path="' + projectPath + '">')
         lines.push(' ')
         lines.push('```')
         lines.push(new TextDecoder().decode(await vscode.workspace.fs.readFile(file)))
         lines.push('```')
         lines.push('</file>')
     }
+    lines.push('\n</details>\n')
     await vscode.env.clipboard.writeText(lines.join('\n'))
     vscode.window.showInformationMessage('Copied to clipboard as Prompt XML.')
 }
````

0 commit comments