 import json
 import re
 from abc import ABC, abstractmethod
-from typing import Any, AsyncGenerator, Optional, Union
+from typing import Any, AsyncGenerator, Optional

 from openai.types.chat import ChatCompletion, ChatCompletionMessageParam

@@ -90,12 +90,13 @@ async def run_without_streaming(
         )
         chat_completion_response: ChatCompletion = await chat_coroutine
         chat_resp = chat_completion_response.model_dump()  # Convert to dict to make it JSON serializable
-        chat_resp["choices"][0]["context"] = extra_info
+        chat_resp = chat_resp["choices"][0]
+        chat_resp["context"] = extra_info
         if overrides.get("suggest_followup_questions"):
-            content, followup_questions = self.extract_followup_questions(chat_resp["choices"][0]["message"]["content"])
-            chat_resp["choices"][0]["message"]["content"] = content
-            chat_resp["choices"][0]["context"]["followup_questions"] = followup_questions
-        chat_resp["choices"][0]["session_state"] = session_state
+            content, followup_questions = self.extract_followup_questions(chat_resp["message"]["content"])
+            chat_resp["message"]["content"] = content
+            chat_resp["context"]["followup_questions"] = followup_questions
+        chat_resp["session_state"] = session_state
         return chat_resp

     async def run_with_streaming(
@@ -108,64 +109,49 @@ async def run_with_streaming(
         extra_info, chat_coroutine = await self.run_until_final_call(
             messages, overrides, auth_claims, should_stream=True
         )
-        yield {
-            "choices": [
-                {
-                    "delta": {"role": "assistant"},
-                    "context": extra_info,
-                    "session_state": session_state,
-                    "finish_reason": None,
-                    "index": 0,
-                }
-            ],
-            "object": "chat.completion.chunk",
-        }
+        yield {"delta": {"role": "assistant"}, "context": extra_info, "session_state": session_state}

         followup_questions_started = False
         followup_content = ""
         async for event_chunk in await chat_coroutine:
             # "2023-07-01-preview" API version has a bug where first response has empty choices
             event = event_chunk.model_dump()  # Convert pydantic model to dict
             if event["choices"]:
+                completion = {"delta": event["choices"][0]["delta"]}
                 # if event contains << and not >>, it is start of follow-up question, truncate
-                content = event["choices"][0]["delta"].get("content")
+                content = completion["delta"].get("content")
                 content = content or ""  # content may either not exist in delta, or explicitly be None
                 if overrides.get("suggest_followup_questions") and "<<" in content:
                     followup_questions_started = True
                     earlier_content = content[: content.index("<<")]
                     if earlier_content:
-                        event["choices"][0]["delta"]["content"] = earlier_content
-                        yield event
+                        completion["delta"]["content"] = earlier_content
+                        yield completion
                     followup_content += content[content.index("<<") :]
                 elif followup_questions_started:
                     followup_content += content
                 else:
-                    yield event
+                    yield completion
         if followup_content:
             _, followup_questions = self.extract_followup_questions(followup_content)
-            yield {
-                "choices": [
-                    {
-                        "delta": {"role": "assistant"},
-                        "context": {"followup_questions": followup_questions},
-                        "finish_reason": None,
-                        "index": 0,
-                    }
-                ],
-                "object": "chat.completion.chunk",
-            }
+            yield {"delta": {"role": "assistant"}, "context": {"followup_questions": followup_questions}}

     async def run(
         self,
         messages: list[ChatCompletionMessageParam],
-        stream: bool = False,
         session_state: Any = None,
         context: dict[str, Any] = {},
-    ) -> Union[dict[str, Any], AsyncGenerator[dict[str, Any], None]]:
+    ) -> dict[str, Any]:
         overrides = context.get("overrides", {})
         auth_claims = context.get("auth_claims", {})
+        return await self.run_without_streaming(messages, overrides, auth_claims, session_state)

-        if stream is False:
-            return await self.run_without_streaming(messages, overrides, auth_claims, session_state)
-        else:
-            return self.run_with_streaming(messages, overrides, auth_claims, session_state)
+    async def run_stream(
+        self,
+        messages: list[ChatCompletionMessageParam],
+        session_state: Any = None,
+        context: dict[str, Any] = {},
+    ) -> AsyncGenerator[dict[str, Any], None]:
+        overrides = context.get("overrides", {})
+        auth_claims = context.get("auth_claims", {})
+        return self.run_with_streaming(messages, overrides, auth_claims, session_state)
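Below is a minimal, self-contained sketch of how a caller might consume the split API after this change. The FakeApproach class and its canned responses are illustrative assumptions, not part of the diff: run() now always returns a single choice-shaped dict, while run_stream() hands back an async generator of delta-shaped dicts.

# Usage sketch (assumed stand-in class; mimics the new return shapes, makes no OpenAI calls)
import asyncio
from typing import Any, AsyncGenerator


class FakeApproach:
    async def run(
        self,
        messages: list[dict[str, Any]],
        session_state: Any = None,
        context: dict[str, Any] = {},
    ) -> dict[str, Any]:
        # Non-streaming: one choice-shaped dict with "message", "context", "session_state".
        return {
            "message": {"role": "assistant", "content": "Hello"},
            "context": {},
            "session_state": session_state,
        }

    async def run_stream(
        self,
        messages: list[dict[str, Any]],
        session_state: Any = None,
        context: dict[str, Any] = {},
    ) -> AsyncGenerator[dict[str, Any], None]:
        async def generate():
            # Streaming: a first chunk carrying role/context/session_state, then content deltas.
            yield {"delta": {"role": "assistant"}, "context": {}, "session_state": session_state}
            yield {"delta": {"content": "Hello"}}

        return generate()


async def main() -> None:
    approach = FakeApproach()

    # Non-streaming path: await run() directly and read the single message.
    response = await approach.run([{"role": "user", "content": "hi"}])
    print(response["message"]["content"])

    # Streaming path: await run_stream() to obtain the async generator, then iterate it.
    async for chunk in await approach.run_stream([{"role": "user", "content": "hi"}]):
        print(chunk)


asyncio.run(main())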