
Commit b4c45b1

Adds optional streaming support for chat requests (#532)
* Quart draft
* Fix ask and test
* Quart deploying now
* Use semantic
* Get tests working
* Revert simple
* Typing fixes
* dont use pipe
* Initial draft of streaming for quart
* Get streaming working
* Revert unneeded changes
* Update version
* Use __anext__
* Update package lock
* Update ndjson output per recommendations
* Typing changes
* Split into 2 endpoints, more tests
* Formatting fixes
* Sources moved to end
* Dont use session workaround for streaming
* Ignore timeout type issue
1 parent: e1a077d

File tree

25 files changed (+424, -75 lines)


.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
+exclude: '^tests/snapshots/'
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v4.4.0

app/backend/app.py

Lines changed: 28 additions & 1 deletion
@@ -1,8 +1,10 @@
 import io
+import json
 import logging
 import mimetypes
 import os
 import time
+from typing import AsyncGenerator

 import aiohttp
 import openai
@@ -18,6 +20,7 @@
     abort,
     current_app,
     jsonify,
+    make_response,
     request,
     send_file,
     send_from_directory,
@@ -97,12 +100,36 @@ async def chat():
         # Workaround for: https://github.com/openai/openai-python/issues/371
         async with aiohttp.ClientSession() as s:
             openai.aiosession.set(s)
-            r = await impl.run(request_json["history"], request_json.get("overrides") or {})
+            r = await impl.run_without_streaming(request_json["history"], request_json.get("overrides", {}))
         return jsonify(r)
     except Exception as e:
         logging.exception("Exception in /chat")
         return jsonify({"error": str(e)}), 500

+
+async def format_as_ndjson(r: AsyncGenerator[dict, None]) -> AsyncGenerator[str, None]:
+    async for event in r:
+        yield json.dumps(event, ensure_ascii=False) + "\n"
+
+@bp.route("/chat_stream", methods=["POST"])
+async def chat_stream():
+    if not request.is_json:
+        return jsonify({"error": "request must be json"}), 415
+    request_json = await request.get_json()
+    approach = request_json["approach"]
+    try:
+        impl = current_app.config[CONFIG_CHAT_APPROACHES].get(approach)
+        if not impl:
+            return jsonify({"error": "unknown approach"}), 400
+        response_generator = impl.run_with_streaming(request_json["history"], request_json.get("overrides", {}))
+        response = await make_response(format_as_ndjson(response_generator))
+        response.timeout = None  # type: ignore
+        return response
+    except Exception as e:
+        logging.exception("Exception in /chat")
+        return jsonify({"error": str(e)}), 500
+
+
 @bp.before_request
 async def ensure_openai_token():
     openai_token = current_app.config[CONFIG_OPENAI_TOKEN]
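
A minimal sketch, not part of this commit, of how a Python client might consume the new NDJSON endpoint. The port (50505) and the "rrr" approach key are assumptions for illustration; what the handler above guarantees is the framing: one JSON object per line, with the first line carrying the extra info (data_points, thoughts) and later lines carrying completion chunks.

import asyncio
import json

import aiohttp


async def consume_chat_stream() -> None:
    body = {
        "approach": "rrr",  # hypothetical approach key, depends on app config
        "history": [{"user": "What does a product manager do?"}],
        "overrides": {},
    }
    # Base URL is an assumption; adjust to wherever the Quart app is listening.
    async with aiohttp.ClientSession("http://localhost:50505") as session:
        async with session.post("/chat_stream", json=body) as resp:
            resp.raise_for_status()
            # aiohttp's StreamReader iterates line by line, which matches NDJSON framing.
            async for line in resp.content:
                if line.strip():
                    print(json.loads(line))


asyncio.run(consume_chat_stream())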

app/backend/approaches/approach.py

Lines changed: 1 addition & 7 deletions
@@ -2,13 +2,7 @@
 from typing import Any


-class ChatApproach(ABC):
-    @abstractmethod
-    async def run(self, history: list[dict], overrides: dict[str, Any]) -> Any:
-        ...
-
-
 class AskApproach(ABC):
     @abstractmethod
-    async def run(self, q: str, overrides: dict[str, Any]) -> Any:
+    async def run(self, q: str, overrides: dict[str, Any]) -> dict[str, Any]:
         ...

app/backend/approaches/chatreadretrieveread.py

Lines changed: 26 additions & 16 deletions
@@ -1,16 +1,15 @@
-from typing import Any
+from typing import Any, AsyncGenerator

 import openai
 from azure.search.documents.aio import SearchClient
 from azure.search.documents.models import QueryType

-from approaches.approach import ChatApproach
 from core.messagebuilder import MessageBuilder
 from core.modelhelper import get_token_limit
 from text import nonewlines


-class ChatReadRetrieveReadApproach(ChatApproach):
+class ChatReadRetrieveReadApproach:
     # Chat roles
     SYSTEM = "system"
     USER = "user"
@@ -57,7 +56,7 @@ def __init__(self, search_client: SearchClient, chatgpt_deployment: str, chatgpt
         self.content_field = content_field
         self.chatgpt_token_limit = get_token_limit(chatgpt_model)

-    async def run(self, history: list[dict[str, str]], overrides: dict[str, Any]) -> Any:
+    async def run_until_final_call(self, history: list[dict[str, str]], overrides: dict[str, Any], should_stream: bool = False) -> tuple:
         has_text = overrides.get("retrieval_mode") in ["text", "hybrid", None]
         has_vector = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
         use_semantic_captions = True if overrides.get("semantic_captions") and has_text else False
@@ -146,20 +145,31 @@ async def run(self, history: list[dict[str, str]], overrides: dict[str, Any]) ->
             history,
             history[-1]["user"] + "\n\nSources:\n" + content,  # Model does not handle lengthy system messages well. Moving sources to latest user conversation to solve follow up questions prompt.
             max_tokens=self.chatgpt_token_limit)
-
-        chat_completion = await openai.ChatCompletion.acreate(
-            deployment_id=self.chatgpt_deployment,
-            model=self.chatgpt_model,
-            messages=messages,
-            temperature=overrides.get("temperature") or 0.7,
-            max_tokens=1024,
-            n=1)
-
-        chat_content = chat_completion.choices[0].message.content
-
         msg_to_display = '\n\n'.join([str(message) for message in messages])

-        return {"data_points": results, "answer": chat_content, "thoughts": f"Searched for:<br>{query_text}<br><br>Conversations:<br>" + msg_to_display.replace('\n', '<br>')}
+        extra_info = {"data_points": results, "thoughts": f"Searched for:<br>{query_text}<br><br>Conversations:<br>" + msg_to_display.replace('\n', '<br>')}
+        chat_coroutine = openai.ChatCompletion.acreate(
+            deployment_id=self.chatgpt_deployment,
+            model=self.chatgpt_model,
+            messages=messages,
+            temperature=overrides.get("temperature") or 0.7,
+            max_tokens=1024,
+            n=1,
+            stream=should_stream)
+        return (extra_info, chat_coroutine)
+
+    async def run_without_streaming(self, history: list[dict[str, str]], overrides: dict[str, Any]) -> dict[str, Any]:
+        extra_info, chat_coroutine = await self.run_until_final_call(history, overrides, should_stream=False)
+        chat_content = (await chat_coroutine).choices[0].message.content
+        extra_info["answer"] = chat_content
+        return extra_info
+
+    async def run_with_streaming(self, history: list[dict[str, str]], overrides: dict[str, Any]) -> AsyncGenerator[dict, None]:
+        extra_info, chat_coroutine = await self.run_until_final_call(history, overrides, should_stream=True)
+        yield extra_info
+        async for event in await chat_coroutine:
+            yield event
+

     def get_messages_from_history(self, system_prompt: str, model_id: str, history: list[dict[str, str]], user_conv: str, few_shots = [], max_tokens: int = 4096) -> list:
         message_builder = MessageBuilder(system_prompt, model_id)
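
A self-contained sketch of the pattern this diff introduces, with the OpenAI and Azure Search calls replaced by stand-ins (everything below is illustrative, not code from the commit): run_until_final_call builds everything up to the completion call and returns that call unawaited, so the non-streaming path can await it once while the streaming path yields the metadata dict first and then relays each chunk.

import asyncio
from typing import AsyncGenerator


async def fake_acreate(stream: bool):
    # Stand-in for openai.ChatCompletion.acreate(..., stream=should_stream).
    if not stream:
        return {"choices": [{"message": {"content": "hello world"}}]}

    async def chunks() -> AsyncGenerator[dict, None]:
        for token in ("hello", " world"):
            yield {"choices": [{"delta": {"content": token}}]}

    return chunks()


async def run_until_final_call(should_stream: bool = False) -> tuple:
    # Retrieval and prompt construction would happen here; return the
    # completion call unawaited so the caller decides how to consume it.
    extra_info = {"data_points": ["doc1.txt: ..."], "thoughts": "Searched for: ..."}
    return (extra_info, fake_acreate(stream=should_stream))


async def run_without_streaming() -> dict:
    extra_info, chat_coroutine = await run_until_final_call(should_stream=False)
    extra_info["answer"] = (await chat_coroutine)["choices"][0]["message"]["content"]
    return extra_info


async def run_with_streaming() -> AsyncGenerator[dict, None]:
    extra_info, chat_coroutine = await run_until_final_call(should_stream=True)
    yield extra_info  # metadata becomes the first NDJSON line
    async for event in await chat_coroutine:
        yield event  # each completion chunk becomes its own line


async def main() -> None:
    print(await run_without_streaming())
    async for event in run_with_streaming():
        print(event)


asyncio.run(main())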

app/backend/approaches/readdecomposeask.py

Lines changed: 1 addition & 1 deletion
@@ -85,7 +85,7 @@ async def lookup(self, q: str) -> Optional[str]:
             return "\n".join([d['content'] async for d in r])
         return None

-    async def run(self, q: str, overrides: dict[str, Any]) -> Any:
+    async def run(self, q: str, overrides: dict[str, Any]) -> dict[str, Any]:

         search_results = None
         async def search_and_store(q: str) -> Any:

app/backend/approaches/readretrieveread.py

Lines changed: 1 addition & 1 deletion
@@ -99,7 +99,7 @@ async def retrieve(self, query_text: str, overrides: dict[str, Any]) -> Any:
         content = "\n".join(results)
         return results, content

-    async def run(self, q: str, overrides: dict[str, Any]) -> Any:
+    async def run(self, q: str, overrides: dict[str, Any]) -> dict[str, Any]:

         retrieve_results = None
         async def retrieve_and_store(q: str) -> Any:

app/backend/approaches/retrievethenread.py

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ def __init__(self, search_client: SearchClient, openai_deployment: str, chatgpt_
         self.sourcepage_field = sourcepage_field
         self.content_field = content_field

-    async def run(self, q: str, overrides: dict[str, Any]) -> Any:
+    async def run(self, q: str, overrides: dict[str, Any]) -> dict[str, Any]:
         has_text = overrides.get("retrieval_mode") in ["text", "hybrid", None]
         has_vector = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
         use_semantic_captions = True if overrides.get("semantic_captions") and has_text else False

app/frontend/package-lock.json

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default.

app/frontend/package.json

Lines changed: 2 additions & 1 deletion
@@ -18,7 +18,8 @@
         "dompurify": "^3.0.4",
         "react": "^18.2.0",
         "react-dom": "^18.2.0",
-        "react-router-dom": "^6.14.1"
+        "react-router-dom": "^6.14.1",
+        "ndjson-readablestream": "^1.0.6"
     },
     "devDependencies": {
         "@types/dompurify": "^3.0.2",

app/frontend/src/api/api.ts

Lines changed: 3 additions & 9 deletions
@@ -31,8 +31,9 @@ export async function askApi(options: AskRequest): Promise<AskResponse> {
     return parsedResponse;
 }

-export async function chatApi(options: ChatRequest): Promise<AskResponse> {
-    const response = await fetch("/chat", {
+export async function chatApi(options: ChatRequest): Promise<Response> {
+    const url = options.shouldStream ? "/chat_stream" : "/chat";
+    return await fetch(url, {
         method: "POST",
         headers: {
             "Content-Type": "application/json"
@@ -54,13 +55,6 @@ export async function chatApi(options: ChatRequest): Promise<AskResponse> {
             }
         })
     });
-
-    const parsedResponse: AskResponse = await response.json();
-    if (response.status > 299 || !response.ok) {
-        throw Error(parsedResponse.error || "Unknown error");
-    }
-
-    return parsedResponse;
 }

 export function getCitationFilePath(citation: string): string {
