@@ -1,30 +1,19 @@
import json
-import logging
import re
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Optional, Union

-from openai.types.chat import (
-    ChatCompletion,
-    ChatCompletionContentPartParam,
-    ChatCompletionMessageParam,
-)
+from openai.types.chat import ChatCompletion, ChatCompletionMessageParam

from approaches.approach import Approach
-from core.messagebuilder import MessageBuilder


class ChatApproach(Approach, ABC):
-    # Chat roles
-    SYSTEM = "system"
-    USER = "user"
-    ASSISTANT = "assistant"
-
-    query_prompt_few_shots = [
-        {"role": USER, "content": "How did crypto do last year?"},
-        {"role": ASSISTANT, "content": "Summarize Cryptocurrency Market Dynamics from last year"},
-        {"role": USER, "content": "What are my health plans?"},
-        {"role": ASSISTANT, "content": "Show available health plans"},
+    query_prompt_few_shots: list[ChatCompletionMessageParam] = [
+        {"role": "user", "content": "How did crypto do last year?"},
+        {"role": "assistant", "content": "Summarize Cryptocurrency Market Dynamics from last year"},
+        {"role": "user", "content": "What are my health plans?"},
+        {"role": "assistant", "content": "Show available health plans"},
    ]
    NO_RESPONSE = "0"
@@ -53,7 +42,7 @@ def system_message_chat_conversation(self) -> str:
        pass

    @abstractmethod
-    async def run_until_final_call(self, history, overrides, auth_claims, should_stream) -> tuple:
+    async def run_until_final_call(self, messages, overrides, auth_claims, should_stream) -> tuple:
        pass

    def get_system_prompt(self, override_prompt: Optional[str], follow_up_questions_prompt: str) -> str:
@@ -89,48 +78,15 @@ def get_search_query(self, chat_completion: ChatCompletion, user_query: str):
    def extract_followup_questions(self, content: str):
        return content.split("<<")[0], re.findall(r"<<([^>>]+)>>", content)

-    def get_messages_from_history(
-        self,
-        system_prompt: str,
-        model_id: str,
-        history: list[dict[str, str]],
-        user_content: Union[str, list[ChatCompletionContentPartParam]],
-        max_tokens: int,
-        few_shots=[],
-    ) -> list[ChatCompletionMessageParam]:
-        message_builder = MessageBuilder(system_prompt, model_id)
-
-        # Add examples to show the chat what responses we want. It will try to mimic any responses and make sure they match the rules laid out in the system message.
-        for shot in reversed(few_shots):
-            message_builder.insert_message(shot.get("role"), shot.get("content"))
-
-        append_index = len(few_shots) + 1
-
-        message_builder.insert_message(self.USER, user_content, index=append_index)
-
-        total_token_count = 0
-        for existing_message in message_builder.messages:
-            total_token_count += message_builder.count_tokens_for_message(existing_message)
-
-        newest_to_oldest = list(reversed(history[:-1]))
-        for message in newest_to_oldest:
-            potential_message_count = message_builder.count_tokens_for_message(message)
-            if (total_token_count + potential_message_count) > max_tokens:
-                logging.info("Reached max tokens of %d, history will be truncated", max_tokens)
-                break
-            message_builder.insert_message(message["role"], message["content"], index=append_index)
-            total_token_count += potential_message_count
-        return message_builder.messages
-
    async def run_without_streaming(
        self,
-        history: list[dict[str, str]],
+        messages: list[ChatCompletionMessageParam],
        overrides: dict[str, Any],
        auth_claims: dict[str, Any],
        session_state: Any = None,
    ) -> dict[str, Any]:
        extra_info, chat_coroutine = await self.run_until_final_call(
-            history, overrides, auth_claims, should_stream=False
+            messages, overrides, auth_claims, should_stream=False
        )
        chat_completion_response: ChatCompletion = await chat_coroutine
        chat_resp = chat_completion_response.model_dump()  # Convert to dict to make it JSON serializable
@@ -144,18 +100,18 @@ async def run_without_streaming(

    async def run_with_streaming(
        self,
-        history: list[dict[str, str]],
+        messages: list[ChatCompletionMessageParam],
        overrides: dict[str, Any],
        auth_claims: dict[str, Any],
        session_state: Any = None,
    ) -> AsyncGenerator[dict, None]:
        extra_info, chat_coroutine = await self.run_until_final_call(
-            history, overrides, auth_claims, should_stream=True
+            messages, overrides, auth_claims, should_stream=True
        )
        yield {
            "choices": [
                {
-                    "delta": {"role": self.ASSISTANT},
+                    "delta": {"role": "assistant"},
                    "context": extra_info,
                    "session_state": session_state,
                    "finish_reason": None,
@@ -190,7 +146,7 @@ async def run_with_streaming(
            yield {
                "choices": [
                    {
-                        "delta": {"role": self.ASSISTANT},
+                        "delta": {"role": "assistant"},
                        "context": {"followup_questions": followup_questions},
                        "finish_reason": None,
                        "index": 0,
@@ -200,7 +156,11 @@ async def run_with_streaming(
            }

    async def run(
-        self, messages: list[dict], stream: bool = False, session_state: Any = None, context: dict[str, Any] = {}
+        self,
+        messages: list[ChatCompletionMessageParam],
+        stream: bool = False,
+        session_state: Any = None,
+        context: dict[str, Any] = {},
    ) -> Union[dict[str, Any], AsyncGenerator[dict[str, Any], None]]:
        overrides = context.get("overrides", {})
        auth_claims = context.get("auth_claims", {})
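
Taken together, the hunks above narrow the public surface of ChatApproach: callers now pass openai's typed ChatCompletionMessageParam dicts rather than an untyped list[dict] history, and the SYSTEM/USER/ASSISTANT class constants give way to literal role strings. A minimal caller sketch under those assumptions follows; handle_chat and the way the concrete approach instance is obtained are hypothetical placeholders, and only the run signature itself comes from the diff.

from typing import Any

from openai.types.chat import ChatCompletionMessageParam

async def handle_chat(approach: Any) -> Any:
    # `approach` is assumed to be some concrete ChatApproach subclass that
    # implements run_until_final_call; constructing one is app-specific.
    messages: list[ChatCompletionMessageParam] = [
        {"role": "user", "content": "What are my health plans?"},
    ]
    # With stream=False, run resolves to a plain JSON-serializable dict;
    # the typed annotation lets a checker validate the "role" literals above.
    return await approach.run(
        messages,
        stream=False,
        session_state=None,
        context={"overrides": {}, "auth_claims": {}},
    )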
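Unchanged by this diff but central to the streaming path, extract_followup_questions produces the followup_questions context emitted near the end of run_with_streaming. A standalone sketch of its behavior, written as a free function for brevity and with an illustrative sample string:

import re

def extract_followup_questions(content: str):
    # Everything before the first "<<" is the answer text; each <<...>>
    # marker yields one follow-up question.
    return content.split("<<")[0], re.findall(r"<<([^>>]+)>>", content)

answer, followups = extract_followup_questions(
    "Here is a summary of your plans. <<What does the dental plan cover?>><<How do I enroll?>>"
)
assert answer == "Here is a summary of your plans. "
assert followups == ["What does the dental plan cover?", "How do I enroll?"]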