
Commit 77303da

Improve token counting for messages with package (#1577)
* Disable openai key access
* Use message token helper instead
* Update to latest package
* Revert launch change
* Improve typing
1 parent e6fa39f commit 77303da

15 files changed, +109 −650 lines

app/backend/approaches/approach.py

Lines changed: 6 additions & 1 deletion
@@ -23,6 +23,7 @@
     VectorQuery,
 )
 from openai import AsyncOpenAI
+from openai.types.chat import ChatCompletionMessageParam
 
 from core.authentication import AuthenticationHelper
 from text import nonewlines
@@ -254,6 +255,10 @@ async def compute_image_embedding(self, q: str):
         return VectorizedQuery(vector=image_query_vector, k_nearest_neighbors=50, fields="imageEmbedding")
 
     async def run(
-        self, messages: list[dict], stream: bool = False, session_state: Any = None, context: dict[str, Any] = {}
+        self,
+        messages: list[ChatCompletionMessageParam],
+        stream: bool = False,
+        session_state: Any = None,
+        context: dict[str, Any] = {},
     ) -> Union[dict[str, Any], AsyncGenerator[dict[str, Any], None]]:
         raise NotImplementedError
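
With this change, `messages` is typed with the openai package's `ChatCompletionMessageParam` union instead of a loose `list[dict]`, so call sites can be type-checked. A minimal sketch of a typed call site (the `approach` instance comes from the app's startup wiring, and the question text is a placeholder):

```python
from openai.types.chat import ChatCompletionMessageParam

# A typed chat history: a type checker can now validate roles and content shapes.
messages: list[ChatCompletionMessageParam] = [
    {"role": "user", "content": "What is included in my Northwind Health Plus plan?"},
]

# Hypothetical invocation, assuming `approach` is a concrete Approach subclass:
# result = await approach.run(messages, stream=False, context={"overrides": {}})
```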

app/backend/approaches/chatapproach.py

Lines changed: 18 additions & 58 deletions
@@ -1,30 +1,19 @@
 import json
-import logging
 import re
 from abc import ABC, abstractmethod
 from typing import Any, AsyncGenerator, Optional, Union
 
-from openai.types.chat import (
-    ChatCompletion,
-    ChatCompletionContentPartParam,
-    ChatCompletionMessageParam,
-)
+from openai.types.chat import ChatCompletion, ChatCompletionMessageParam
 
 from approaches.approach import Approach
-from core.messagebuilder import MessageBuilder
 
 
 class ChatApproach(Approach, ABC):
-    # Chat roles
-    SYSTEM = "system"
-    USER = "user"
-    ASSISTANT = "assistant"
-
-    query_prompt_few_shots = [
-        {"role": USER, "content": "How did crypto do last year?"},
-        {"role": ASSISTANT, "content": "Summarize Cryptocurrency Market Dynamics from last year"},
-        {"role": USER, "content": "What are my health plans?"},
-        {"role": ASSISTANT, "content": "Show available health plans"},
+    query_prompt_few_shots: list[ChatCompletionMessageParam] = [
+        {"role": "user", "content": "How did crypto do last year?"},
+        {"role": "assistant", "content": "Summarize Cryptocurrency Market Dynamics from last year"},
+        {"role": "user", "content": "What are my health plans?"},
+        {"role": "assistant", "content": "Show available health plans"},
     ]
     NO_RESPONSE = "0"
 
@@ -53,7 +42,7 @@ def system_message_chat_conversation(self) -> str:
         pass
 
     @abstractmethod
-    async def run_until_final_call(self, history, overrides, auth_claims, should_stream) -> tuple:
+    async def run_until_final_call(self, messages, overrides, auth_claims, should_stream) -> tuple:
         pass
 
     def get_system_prompt(self, override_prompt: Optional[str], follow_up_questions_prompt: str) -> str:
@@ -89,48 +78,15 @@ def get_search_query(self, chat_completion: ChatCompletion, user_query: str):
     def extract_followup_questions(self, content: str):
         return content.split("<<")[0], re.findall(r"<<([^>>]+)>>", content)
 
-    def get_messages_from_history(
-        self,
-        system_prompt: str,
-        model_id: str,
-        history: list[dict[str, str]],
-        user_content: Union[str, list[ChatCompletionContentPartParam]],
-        max_tokens: int,
-        few_shots=[],
-    ) -> list[ChatCompletionMessageParam]:
-        message_builder = MessageBuilder(system_prompt, model_id)
-
-        # Add examples to show the chat what responses we want. It will try to mimic any responses and make sure they match the rules laid out in the system message.
-        for shot in reversed(few_shots):
-            message_builder.insert_message(shot.get("role"), shot.get("content"))
-
-        append_index = len(few_shots) + 1
-
-        message_builder.insert_message(self.USER, user_content, index=append_index)
-
-        total_token_count = 0
-        for existing_message in message_builder.messages:
-            total_token_count += message_builder.count_tokens_for_message(existing_message)
-
-        newest_to_oldest = list(reversed(history[:-1]))
-        for message in newest_to_oldest:
-            potential_message_count = message_builder.count_tokens_for_message(message)
-            if (total_token_count + potential_message_count) > max_tokens:
-                logging.info("Reached max tokens of %d, history will be truncated", max_tokens)
-                break
-            message_builder.insert_message(message["role"], message["content"], index=append_index)
-            total_token_count += potential_message_count
-        return message_builder.messages
-
     async def run_without_streaming(
         self,
-        history: list[dict[str, str]],
+        messages: list[ChatCompletionMessageParam],
         overrides: dict[str, Any],
         auth_claims: dict[str, Any],
         session_state: Any = None,
     ) -> dict[str, Any]:
         extra_info, chat_coroutine = await self.run_until_final_call(
-            history, overrides, auth_claims, should_stream=False
+            messages, overrides, auth_claims, should_stream=False
         )
         chat_completion_response: ChatCompletion = await chat_coroutine
         chat_resp = chat_completion_response.model_dump()  # Convert to dict to make it JSON serializable
@@ -144,18 +100,18 @@ async def run_without_streaming(
 
     async def run_with_streaming(
         self,
-        history: list[dict[str, str]],
+        messages: list[ChatCompletionMessageParam],
         overrides: dict[str, Any],
         auth_claims: dict[str, Any],
         session_state: Any = None,
     ) -> AsyncGenerator[dict, None]:
         extra_info, chat_coroutine = await self.run_until_final_call(
-            history, overrides, auth_claims, should_stream=True
+            messages, overrides, auth_claims, should_stream=True
         )
         yield {
             "choices": [
                 {
-                    "delta": {"role": self.ASSISTANT},
+                    "delta": {"role": "assistant"},
                     "context": extra_info,
                     "session_state": session_state,
                     "finish_reason": None,
@@ -190,7 +146,7 @@ async def run_with_streaming(
             yield {
                 "choices": [
                     {
-                        "delta": {"role": self.ASSISTANT},
+                        "delta": {"role": "assistant"},
                         "context": {"followup_questions": followup_questions},
                         "finish_reason": None,
                         "index": 0,
@@ -200,7 +156,11 @@ async def run_with_streaming(
         }
 
     async def run(
-        self, messages: list[dict], stream: bool = False, session_state: Any = None, context: dict[str, Any] = {}
+        self,
+        messages: list[ChatCompletionMessageParam],
+        stream: bool = False,
+        session_state: Any = None,
+        context: dict[str, Any] = {},
     ) -> Union[dict[str, Any], AsyncGenerator[dict[str, Any], None]]:
         overrides = context.get("overrides", {})
         auth_claims = context.get("auth_claims", {})
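
The deleted `get_messages_from_history` counted tokens and truncated history by hand; the commit delegates that to `build_messages` from the `openai_messages_token_helper` package. A rough sketch of an equivalent standalone call, using only the keyword arguments that appear in this diff (the model name and prompt strings are placeholder values, and the truncation behavior described in the comment is assumed to match what the removed helper did):

```python
from openai_messages_token_helper import build_messages

# build_messages injects the system prompt and few-shot examples, appends the new
# user content, and fits as much prior history as the max_tokens budget allows --
# replacing the manual token-counting loop removed above.
query_messages = build_messages(
    model="gpt-35-turbo",  # placeholder model name
    system_prompt="Generate a search query for the question below.",  # placeholder prompt
    few_shots=[
        {"role": "user", "content": "How did crypto do last year?"},
        {"role": "assistant", "content": "Summarize Cryptocurrency Market Dynamics from last year"},
    ],
    past_messages=[],  # earlier conversation turns, excluding the newest user turn
    new_user_content="Generate search query for: What are my health plans?",
    max_tokens=4096 - 100,  # context budget minus room reserved for the response
)
```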

app/backend/approaches/chatreadretrieveread.py

Lines changed: 21 additions & 18 deletions
@@ -6,13 +6,14 @@
 from openai.types.chat import (
     ChatCompletion,
     ChatCompletionChunk,
+    ChatCompletionMessageParam,
     ChatCompletionToolParam,
 )
+from openai_messages_token_helper import build_messages, get_token_limit
 
 from approaches.approach import ThoughtStep
 from approaches.chatapproach import ChatApproach
 from core.authentication import AuthenticationHelper
-from core.modelhelper import get_token_limit
 
 
 class ChatReadRetrieveReadApproach(ChatApproach):
@@ -65,7 +66,7 @@ def system_message_chat_conversation(self):
     @overload
     async def run_until_final_call(
         self,
-        history: list[dict[str, str]],
+        messages: list[ChatCompletionMessageParam],
         overrides: dict[str, Any],
         auth_claims: dict[str, Any],
         should_stream: Literal[False],
@@ -74,15 +75,15 @@ async def run_until_final_call(
     @overload
     async def run_until_final_call(
         self,
-        history: list[dict[str, str]],
+        messages: list[ChatCompletionMessageParam],
         overrides: dict[str, Any],
         auth_claims: dict[str, Any],
         should_stream: Literal[True],
     ) -> tuple[dict[str, Any], Coroutine[Any, Any, AsyncStream[ChatCompletionChunk]]]: ...
 
     async def run_until_final_call(
         self,
-        history: list[dict[str, str]],
+        messages: list[ChatCompletionMessageParam],
         overrides: dict[str, Any],
         auth_claims: dict[str, Any],
         should_stream: bool = False,
@@ -97,7 +98,9 @@ async def run_until_final_call(
         filter = self.build_filter(overrides, auth_claims)
         use_semantic_ranker = True if overrides.get("semantic_ranker") and has_text else False
 
-        original_user_query = history[-1]["content"]
+        original_user_query = messages[-1]["content"]
+        if not isinstance(original_user_query, str):
+            raise ValueError("The most recent message content must be a string.")
         user_query_request = "Generate search query for: " + original_user_query
 
         tools: List[ChatCompletionToolParam] = [
@@ -121,24 +124,25 @@ async def run_until_final_call(
         ]
 
         # STEP 1: Generate an optimized keyword search query based on the chat history and the last question
-        query_messages = self.get_messages_from_history(
+        query_response_token_limit = 100
+        query_messages = build_messages(
+            model=self.chatgpt_model,
             system_prompt=self.query_prompt_template,
-            model_id=self.chatgpt_model,
-            history=history,
-            user_content=user_query_request,
-            max_tokens=self.chatgpt_token_limit - len(user_query_request),
+            tools=tools,
             few_shots=self.query_prompt_few_shots,
+            past_messages=messages[:-1],
+            new_user_content=user_query_request,
+            max_tokens=self.chatgpt_token_limit - query_response_token_limit,
         )
 
         chat_completion: ChatCompletion = await self.openai_client.chat.completions.create(
             messages=query_messages,  # type: ignore
             # Azure OpenAI takes the deployment name as the model name
             model=self.chatgpt_deployment if self.chatgpt_deployment else self.chatgpt_model,
             temperature=0.0,  # Minimize creativity for search query generation
-            max_tokens=100,  # Setting too low risks malformed JSON, setting too high may affect performance
+            max_tokens=query_response_token_limit,  # Setting too low risks malformed JSON, setting too high may affect performance
             n=1,
             tools=tools,
-            tool_choice="auto",
         )
 
         query_text = self.get_search_query(chat_completion, original_user_query)
@@ -177,14 +181,13 @@ async def run_until_final_call(
         )
 
         response_token_limit = 1024
-        messages_token_limit = self.chatgpt_token_limit - response_token_limit
-        messages = self.get_messages_from_history(
+        messages = build_messages(
+            model=self.chatgpt_model,
             system_prompt=system_message,
-            model_id=self.chatgpt_model,
-            history=history,
+            past_messages=messages[:-1],
             # Model does not handle lengthy system messages well. Moving sources to latest user conversation to solve follow up questions prompt.
-            user_content=original_user_query + "\n\nSources:\n" + content,
-            max_tokens=messages_token_limit,
+            new_user_content=original_user_query + "\n\nSources:\n" + content,
+            max_tokens=self.chatgpt_token_limit - response_token_limit,
         )
 
         data_points = {"text": sources_content}
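
Note the token-budget fix in STEP 1: the old call subtracted `len(user_query_request)`, a character count, from the token limit, while the new call reserves a fixed `query_response_token_limit` of 100 tokens for the completion. A small sketch of that budgeting, assuming `get_token_limit` returns the model's context window size in tokens (the model name is a placeholder; the approach receives it via its constructor):

```python
from openai_messages_token_helper import get_token_limit

chatgpt_model = "gpt-35-turbo"  # placeholder model name
chatgpt_token_limit = get_token_limit(chatgpt_model)  # context window size in tokens

# Reserve room for the model's reply so prompt plus completion fit in the window.
query_response_token_limit = 100
prompt_budget = chatgpt_token_limit - query_response_token_limit
print(f"{chatgpt_token_limit=} {prompt_budget=}")
```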

app/backend/approaches/chatreadretrievereadvision.py

Lines changed: 20 additions & 17 deletions
@@ -8,13 +8,14 @@
     ChatCompletionChunk,
     ChatCompletionContentPartImageParam,
     ChatCompletionContentPartParam,
+    ChatCompletionMessageParam,
 )
+from openai_messages_token_helper import build_messages, get_token_limit
 
 from approaches.approach import ThoughtStep
 from approaches.chatapproach import ChatApproach
 from core.authentication import AuthenticationHelper
 from core.imageshelper import fetch_image
-from core.modelhelper import get_token_limit
 
 
 class ChatReadRetrieveReadVisionApproach(ChatApproach):
@@ -79,7 +80,7 @@ def system_message_chat_conversation(self):
 
     async def run_until_final_call(
         self,
-        history: list[dict[str, str]],
+        messages: list[ChatCompletionMessageParam],
         overrides: dict[str, Any],
         auth_claims: dict[str, Any],
         should_stream: bool = False,
@@ -97,25 +98,29 @@ async def run_until_final_call(
         include_gtpV_text = overrides.get("gpt4v_input") in ["textAndImages", "texts", None]
         include_gtpV_images = overrides.get("gpt4v_input") in ["textAndImages", "images", None]
 
-        original_user_query = history[-1]["content"]
+        original_user_query = messages[-1]["content"]
+        if not isinstance(original_user_query, str):
+            raise ValueError("The most recent message content must be a string.")
+        past_messages: list[ChatCompletionMessageParam] = messages[:-1]
 
         # STEP 1: Generate an optimized keyword search query based on the chat history and the last question
         user_query_request = "Generate search query for: " + original_user_query
 
-        query_messages = self.get_messages_from_history(
+        query_response_token_limit = 100
+        query_messages = build_messages(
+            model=self.gpt4v_model,
             system_prompt=self.query_prompt_template,
-            model_id=self.gpt4v_model,
-            history=history,
-            user_content=user_query_request,
-            max_tokens=self.chatgpt_token_limit - len(" ".join(user_query_request)),
             few_shots=self.query_prompt_few_shots,
+            past_messages=past_messages,
+            new_user_content=user_query_request,
+            max_tokens=self.chatgpt_token_limit - query_response_token_limit,
         )
 
         chat_completion: ChatCompletion = await self.openai_client.chat.completions.create(
             model=self.gpt4v_deployment if self.gpt4v_deployment else self.gpt4v_model,
             messages=query_messages,
             temperature=0.0,  # Minimize creativity for search query generation
-            max_tokens=100,
+            max_tokens=query_response_token_limit,
             n=1,
         )
 
@@ -159,9 +164,6 @@ async def run_until_final_call(
             self.follow_up_questions_prompt_content if overrides.get("suggest_followup_questions") else "",
         )
 
-        response_token_limit = 1024
-        messages_token_limit = self.chatgpt_token_limit - response_token_limit
-
         user_content: list[ChatCompletionContentPartParam] = [{"text": original_user_query, "type": "text"}]
         image_list: list[ChatCompletionContentPartImageParam] = []
 
@@ -174,12 +176,13 @@ async def run_until_final_call(
                 image_list.append({"image_url": url, "type": "image_url"})
         user_content.extend(image_list)
 
-        messages = self.get_messages_from_history(
+        response_token_limit = 1024
+        messages = build_messages(
+            model=self.gpt4v_model,
             system_prompt=system_message,
-            model_id=self.gpt4v_model,
-            history=history,
-            user_content=user_content,
-            max_tokens=messages_token_limit,
+            past_messages=messages[:-1],
+            new_user_content=user_content,
+            max_tokens=self.chatgpt_token_limit - response_token_limit,
         )
 
         data_points = {
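
In the vision approach, `new_user_content` is a list of content parts rather than a string: one text part plus an `image_url` part per retrieved image. A sketch of that shape, with a placeholder question and URL; the nested `{"url": ..., "detail": ...}` dict is an assumption based on the openai content-part types, since the diff passes an already-built `url` value from `fetch_image`:

```python
from openai.types.chat import (
    ChatCompletionContentPartImageParam,
    ChatCompletionContentPartParam,
)

# Multimodal user content: the question text first, then any retrieved images.
user_content: list[ChatCompletionContentPartParam] = [
    {"text": "What does the chart on page 3 show?", "type": "text"},  # placeholder question
]
image_list: list[ChatCompletionContentPartImageParam] = [
    # Placeholder URL; "detail" is assumed from the openai ImageURL typed dict.
    {"image_url": {"url": "https://example.com/page3.png", "detail": "auto"}, "type": "image_url"},
]
user_content.extend(image_list)
```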
