1
1
import json
2
2
import re
3
3
from abc import ABC , abstractmethod
4
- from typing import Any , AsyncGenerator , Optional , Union
4
+ from typing import Any , AsyncGenerator , Optional
5
5
6
6
from openai .types .chat import ChatCompletion , ChatCompletionMessageParam
7
7
@@ -90,12 +90,13 @@ async def run_without_streaming(
90
90
)
91
91
chat_completion_response : ChatCompletion = await chat_coroutine
92
92
chat_resp = chat_completion_response .model_dump () # Convert to dict to make it JSON serializable
93
- chat_resp ["choices" ][0 ]["context" ] = extra_info
93
+ chat_resp = chat_resp ["choices" ][0 ]
94
+ chat_resp ["context" ] = extra_info
94
95
if overrides .get ("suggest_followup_questions" ):
95
- content , followup_questions = self .extract_followup_questions (chat_resp ["choices" ][ 0 ][ " message" ]["content" ])
96
- chat_resp ["choices" ][ 0 ][ " message" ]["content" ] = content
97
- chat_resp ["choices" ][ 0 ][ " context" ]["followup_questions" ] = followup_questions
98
- chat_resp ["choices" ][ 0 ][ " session_state" ] = session_state
96
+ content , followup_questions = self .extract_followup_questions (chat_resp ["message" ]["content" ])
97
+ chat_resp ["message" ]["content" ] = content
98
+ chat_resp ["context" ]["followup_questions" ] = followup_questions
99
+ chat_resp ["session_state" ] = session_state
99
100
return chat_resp
100
101
101
102
async def run_with_streaming (
@@ -108,64 +109,49 @@ async def run_with_streaming(
108
109
extra_info , chat_coroutine = await self .run_until_final_call (
109
110
messages , overrides , auth_claims , should_stream = True
110
111
)
111
- yield {
112
- "choices" : [
113
- {
114
- "delta" : {"role" : "assistant" },
115
- "context" : extra_info ,
116
- "session_state" : session_state ,
117
- "finish_reason" : None ,
118
- "index" : 0 ,
119
- }
120
- ],
121
- "object" : "chat.completion.chunk" ,
122
- }
112
+ yield {"delta" : {"role" : "assistant" }, "context" : extra_info , "session_state" : session_state }
123
113
124
114
followup_questions_started = False
125
115
followup_content = ""
126
116
async for event_chunk in await chat_coroutine :
127
117
# "2023-07-01-preview" API version has a bug where first response has empty choices
128
118
event = event_chunk .model_dump () # Convert pydantic model to dict
129
119
if event ["choices" ]:
120
+ completion = {"delta" : event ["choices" ][0 ]["delta" ]}
130
121
# if event contains << and not >>, it is start of follow-up question, truncate
131
- content = event [ "choices" ][ 0 ] ["delta" ].get ("content" )
122
+ content = completion ["delta" ].get ("content" )
132
123
content = content or "" # content may either not exist in delta, or explicitly be None
133
124
if overrides .get ("suggest_followup_questions" ) and "<<" in content :
134
125
followup_questions_started = True
135
126
earlier_content = content [: content .index ("<<" )]
136
127
if earlier_content :
137
- event [ "choices" ][ 0 ] ["delta" ]["content" ] = earlier_content
138
- yield event
128
+ completion ["delta" ]["content" ] = earlier_content
129
+ yield completion
139
130
followup_content += content [content .index ("<<" ) :]
140
131
elif followup_questions_started :
141
132
followup_content += content
142
133
else :
143
- yield event
134
+ yield completion
144
135
if followup_content :
145
136
_ , followup_questions = self .extract_followup_questions (followup_content )
146
- yield {
147
- "choices" : [
148
- {
149
- "delta" : {"role" : "assistant" },
150
- "context" : {"followup_questions" : followup_questions },
151
- "finish_reason" : None ,
152
- "index" : 0 ,
153
- }
154
- ],
155
- "object" : "chat.completion.chunk" ,
156
- }
137
+ yield {"delta" : {"role" : "assistant" }, "context" : {"followup_questions" : followup_questions }}
157
138
158
139
async def run (
159
140
self ,
160
141
messages : list [ChatCompletionMessageParam ],
161
- stream : bool = False ,
162
142
session_state : Any = None ,
163
143
context : dict [str , Any ] = {},
164
- ) -> Union [ dict [str , Any ], AsyncGenerator [ dict [ str , Any ], None ] ]:
144
+ ) -> dict [str , Any ]:
165
145
overrides = context .get ("overrides" , {})
166
146
auth_claims = context .get ("auth_claims" , {})
147
+ return await self .run_without_streaming (messages , overrides , auth_claims , session_state )
167
148
168
- if stream is False :
169
- return await self .run_without_streaming (messages , overrides , auth_claims , session_state )
170
- else :
171
- return self .run_with_streaming (messages , overrides , auth_claims , session_state )
149
+ async def run_stream (
150
+ self ,
151
+ messages : list [ChatCompletionMessageParam ],
152
+ session_state : Any = None ,
153
+ context : dict [str , Any ] = {},
154
+ ) -> AsyncGenerator [dict [str , Any ], None ]:
155
+ overrides = context .get ("overrides" , {})
156
+ auth_claims = context .get ("auth_claims" , {})
157
+ return self .run_with_streaming (messages , overrides , auth_claims , session_state )
0 commit comments