|
@@ -25,6 +25,7 @@
 from typing import TYPE_CHECKING

 from google.genai import types
+from websockets.exceptions import ConnectionClosed
 from websockets.exceptions import ConnectionClosedOK

 from . import functions
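
A note on the new import: in the websockets library, ConnectionClosedOK (clean close) and ConnectionClosedError (abnormal close) are both subclasses of ConnectionClosed, so the combined except (ConnectionClosed, ConnectionClosedOK) clause added below is equivalent to catching ConnectionClosed alone; listing both is harmless but redundant. A quick self-contained check:

    # Sanity check of the websockets exception hierarchy; runnable as-is
    # with the `websockets` package installed.
    from websockets.exceptions import (
        ConnectionClosed,
        ConnectionClosedError,
        ConnectionClosedOK,
    )

    assert issubclass(ConnectionClosedOK, ConnectionClosed)
    assert issubclass(ConnectionClosedError, ConnectionClosed)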
|
@@ -86,80 +87,115 @@ async def run_live(
         invocation_context.agent.name,
         llm_request,
     )
-    async with llm.connect(llm_request) as llm_connection:
-      if llm_request.contents:
-        # Sends the conversation history to the model.
-        with tracer.start_as_current_span('send_data'):
-
-          if invocation_context.transcription_cache:
-            from . import audio_transcriber
-
-            audio_transcriber = audio_transcriber.AudioTranscriber(
-                init_client=True
-                if invocation_context.run_config.input_audio_transcription
-                is None
-                else False
-            )
-            contents = audio_transcriber.transcribe_file(invocation_context)
-            logger.debug('Sending history to model: %s', contents)
-            await llm_connection.send_history(contents)
-            invocation_context.transcription_cache = None
-            trace_send_data(invocation_context, event_id, contents)
-          else:
-            await llm_connection.send_history(llm_request.contents)
-            trace_send_data(invocation_context, event_id, llm_request.contents)
-
-      send_task = asyncio.create_task(
-          self._send_to_model(llm_connection, invocation_context)
-      )

+    attempt = 1
+    while True:
       try:
-        async for event in self._receive_from_model(
-            llm_connection,
-            event_id,
-            invocation_context,
-            llm_request,
-        ):
-          # Empty event means the queue is closed.
-          if not event:
-            break
-          logger.debug('Receive new event: %s', event)
-          yield event
-          # send back the function response
-          if event.get_function_responses():
-            logger.debug('Sending back last function response event: %s', event)
-            invocation_context.live_request_queue.send_content(event.content)
-          if (
-              event.content
-              and event.content.parts
-              and event.content.parts[0].function_response
-              and event.content.parts[0].function_response.name
-              == 'transfer_to_agent'
-          ):
-            await asyncio.sleep(1)
-            # cancel the tasks that belongs to the closed connection.
-            send_task.cancel()
-            await llm_connection.close()
-          if (
-              event.content
-              and event.content.parts
-              and event.content.parts[0].function_response
-              and event.content.parts[0].function_response.name
-              == 'task_completed'
-          ):
-            # this is used for sequential agent to signal the end of the agent.
-            await asyncio.sleep(1)
-            # cancel the tasks that belongs to the closed connection.
-            send_task.cancel()
-            return
-      finally:
-        # Clean up
-        if not send_task.done():
-          send_task.cancel()
-          try:
-            await send_task
-          except asyncio.CancelledError:
-            pass
+        # On subsequent attempts, use the saved handle to reconnect.
+        if invocation_context.live_session_resumption_handle:
+          logger.info('Attempting to reconnect (Attempt %s)...', attempt)
+          attempt += 1
+          if not llm_request.live_connect_config:
+            llm_request.live_connect_config = types.LiveConnectConfig()
+          llm_request.live_connect_config.session_resumption.handle = (
+              invocation_context.live_session_resumption_handle
+          )
+          llm_request.live_connect_config.session_resumption.transparent = True
+
+        logger.info(
+            'Establishing live connection for agent: %s',
+            invocation_context.agent.name,
+        )
+        async with llm.connect(llm_request) as llm_connection:
+          if llm_request.contents:
+            # Sends the conversation history to the model.
+            with tracer.start_as_current_span('send_data'):
+
+              if invocation_context.transcription_cache:
+                from . import audio_transcriber
+
+                audio_transcriber = audio_transcriber.AudioTranscriber(
+                    init_client=True
+                    if invocation_context.run_config.input_audio_transcription
+                    is None
+                    else False
+                )
+                contents = audio_transcriber.transcribe_file(invocation_context)
+                logger.debug('Sending history to model: %s', contents)
+                await llm_connection.send_history(contents)
+                invocation_context.transcription_cache = None
+                trace_send_data(invocation_context, event_id, contents)
+              else:
+                await llm_connection.send_history(llm_request.contents)
+                trace_send_data(
+                    invocation_context, event_id, llm_request.contents
+                )
+
+          send_task = asyncio.create_task(
+              self._send_to_model(llm_connection, invocation_context)
+          )
+
+          try:
+            async for event in self._receive_from_model(
+                llm_connection,
+                event_id,
+                invocation_context,
+                llm_request,
+            ):
+              # Empty event means the queue is closed.
+              if not event:
+                break
+              logger.debug('Receive new event: %s', event)
+              yield event
+              # Send the function response back to the model.
+              if event.get_function_responses():
+                logger.debug(
+                    'Sending back last function response event: %s', event
+                )
+                invocation_context.live_request_queue.send_content(
+                    event.content
+                )
+              if (
+                  event.content
+                  and event.content.parts
+                  and event.content.parts[0].function_response
+                  and event.content.parts[0].function_response.name
+                  == 'transfer_to_agent'
+              ):
+                await asyncio.sleep(1)
+                # Cancel the tasks that belong to the closed connection.
+                send_task.cancel()
+                await llm_connection.close()
+              if (
+                  event.content
+                  and event.content.parts
+                  and event.content.parts[0].function_response
+                  and event.content.parts[0].function_response.name
+                  == 'task_completed'
+              ):
+                # Used by a sequential agent to signal the end of the agent.
+                await asyncio.sleep(1)
+                # Cancel the tasks that belong to the closed connection.
+                send_task.cancel()
+                return
+          finally:
+            # Clean up.
+            if not send_task.done():
+              send_task.cancel()
+              try:
+                await send_task
+              except asyncio.CancelledError:
+                pass
+      except (ConnectionClosed, ConnectionClosedOK) as e:
+        # A session timeout closes the connection without raising an
+        # exception, so reaching this handler indicates an abnormal close.
+        logger.error(f'Connection closed: {e}.')
+        raise
+      except Exception as e:
+        logger.error(
+            f'An unexpected error occurred in live flow: {e}', exc_info=True
+        )
+        raise

   async def _send_to_model(
       self,
@@ -246,6 +282,14 @@ def get_author_for_event(llm_response):
     try:
       while True:
         async for llm_response in llm_connection.receive():
+          if llm_response.live_session_resumption_update:
+            logger.info(
+                'Update session resumption handle:'
+                f' {llm_response.live_session_resumption_update}.'
+            )
+            invocation_context.live_session_resumption_handle = (
+                llm_response.live_session_resumption_update.new_handle
+            )
           model_response_event = Event(
               id=Event.new_id(),
               invocation_id=invocation_context.invocation_id,
|