Commit 71fbc92

hangfei authored and copybara-github committed

feat: Implement Live Session Resumption

The previous implementation didn't pass the actual resumption handle to the server. Now the handle is cached and passed back to the server when a reconnection happens. To enable:

    run_config = RunConfig(
        session_resumption=types.SessionResumptionConfig(transparent=True)
    )

PiperOrigin-RevId: 791308462
1 parent 423542a commit 71fbc92
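
For reference, a minimal sketch of enabling the feature as the commit message describes; the import paths are assumptions based on the repo layout and are not shown in this diff:

```python
from google.adk.agents.run_config import RunConfig  # assumed import path
from google.genai import types

# Transparent session resumption: the live server streams resumption
# handles, and (per this commit) the live flow caches the latest handle
# and passes it back to the server when it reconnects.
run_config = RunConfig(
    session_resumption=types.SessionResumptionConfig(transparent=True)
)
```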

File tree

9 files changed: +333 −76 lines changed


contributing/samples/live_bidi_streaming_multi_agent/agent.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -100,8 +100,8 @@ def get_current_weather(location: str):

 root_agent = Agent(
     # find supported models here: https://google.github.io/adk-docs/get-started/streaming/quickstart-streaming/
-    # model='gemini-live-2.5-flash-preview-native-audio', # for Vertex project
-    model="gemini-live-2.5-flash-preview", # for AI studio key
+    model="gemini-2.0-flash-live-preview-04-09", # for Vertex project
+    # model="gemini-live-2.5-flash-preview", # for AI studio key
     name="root_agent",
     instruction="""
       You are a helpful assistant that can check time, roll dice and check if numbers are prime.
```

contributing/samples/live_bidi_streaming_tools_agent/agent.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -121,7 +121,9 @@ def stop_streaming(function_name: str):


 root_agent = Agent(
-    model="gemini-live-2.5-flash-preview",
+    # find supported models here: https://google.github.io/adk-docs/get-started/streaming/quickstart-streaming/
+    model="gemini-2.0-flash-live-preview-04-09", # for Vertex project
+    # model="gemini-live-2.5-flash-preview", # for AI studio key
     name="video_streaming_agent",
     instruction="""
       You are a monitoring agent. You can do video monitoring and stock price monitoring
```

contributing/samples/live_tool_callbacks_agent/agent.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -217,8 +217,9 @@ async def after_tool_async_callback(

 # Create the agent with tool callbacks
 root_agent = Agent(
+    # find supported models here: https://google.github.io/adk-docs/get-started/streaming/quickstart-streaming/
     model="gemini-2.0-flash-live-preview-04-09", # for Vertex project
-    # model="gemini-2.0-flash-live-001", # for AI studio key
+    # model="gemini-live-2.5-flash-preview", # for AI studio key
     name="tool_callbacks_agent",
     description=(
         "Live streaming agent that demonstrates tool callbacks functionality. "
```

src/google/adk/agents/invocation_context.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -153,6 +153,9 @@ class InvocationContext(BaseModel):
   transcription_cache: Optional[list[TranscriptionEntry]] = None
   """Caches necessary data, audio or contents, that are needed by transcription."""

+  live_session_resumption_handle: Optional[str] = None
+  """The handle for live session resumption."""
+
   run_config: Optional[RunConfig] = None
   """Configurations for live agents under this invocation."""

```
src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 116 additions & 72 deletions
```diff
@@ -25,6 +25,7 @@
 from typing import TYPE_CHECKING

 from google.genai import types
+from websockets.exceptions import ConnectionClosed
 from websockets.exceptions import ConnectionClosedOK

 from . import functions
```

```diff
@@ -86,80 +87,115 @@ async def run_live(
         invocation_context.agent.name,
         llm_request,
     )
-    async with llm.connect(llm_request) as llm_connection:
-      if llm_request.contents:
-        # Sends the conversation history to the model.
-        with tracer.start_as_current_span('send_data'):
-
-          if invocation_context.transcription_cache:
-            from . import audio_transcriber
-
-            audio_transcriber = audio_transcriber.AudioTranscriber(
-                init_client=True
-                if invocation_context.run_config.input_audio_transcription
-                is None
-                else False
-            )
-            contents = audio_transcriber.transcribe_file(invocation_context)
-            logger.debug('Sending history to model: %s', contents)
-            await llm_connection.send_history(contents)
-            invocation_context.transcription_cache = None
-            trace_send_data(invocation_context, event_id, contents)
-          else:
-            await llm_connection.send_history(llm_request.contents)
-            trace_send_data(invocation_context, event_id, llm_request.contents)
-
-      send_task = asyncio.create_task(
-          self._send_to_model(llm_connection, invocation_context)
-      )

+    attempt = 1
+    while True:
       try:
-        async for event in self._receive_from_model(
-            llm_connection,
-            event_id,
-            invocation_context,
-            llm_request,
-        ):
-          # Empty event means the queue is closed.
-          if not event:
-            break
-          logger.debug('Receive new event: %s', event)
-          yield event
-          # send back the function response
-          if event.get_function_responses():
-            logger.debug('Sending back last function response event: %s', event)
-            invocation_context.live_request_queue.send_content(event.content)
-          if (
-              event.content
-              and event.content.parts
-              and event.content.parts[0].function_response
-              and event.content.parts[0].function_response.name
-              == 'transfer_to_agent'
-          ):
-            await asyncio.sleep(1)
-            # cancel the tasks that belong to the closed connection.
-            send_task.cancel()
-            await llm_connection.close()
-          if (
-              event.content
-              and event.content.parts
-              and event.content.parts[0].function_response
-              and event.content.parts[0].function_response.name
-              == 'task_completed'
-          ):
-            # this is used for the sequential agent to signal the end of the agent.
-            await asyncio.sleep(1)
-            # cancel the tasks that belong to the closed connection.
-            send_task.cancel()
-            return
-      finally:
-        # Clean up
-        if not send_task.done():
-          send_task.cancel()
-        try:
-          await send_task
-        except asyncio.CancelledError:
-          pass
+        # On subsequent attempts, use the saved handle to reconnect.
+        if invocation_context.live_session_resumption_handle:
+          logger.info('Attempting to reconnect (Attempt %s)...', attempt)
+          attempt += 1
+          if not llm_request.live_connect_config:
+            llm_request.live_connect_config = types.LiveConnectConfig()
+          llm_request.live_connect_config.session_resumption.handle = (
+              invocation_context.live_session_resumption_handle
+          )
+          llm_request.live_connect_config.session_resumption.transparent = True
+
+        logger.info(
+            'Establishing live connection for agent: %s',
+            invocation_context.agent.name,
+        )
+        async with llm.connect(llm_request) as llm_connection:
+          if llm_request.contents:
+            # Sends the conversation history to the model.
+            with tracer.start_as_current_span('send_data'):
+
+              if invocation_context.transcription_cache:
+                from . import audio_transcriber
+
+                audio_transcriber = audio_transcriber.AudioTranscriber(
+                    init_client=True
+                    if invocation_context.run_config.input_audio_transcription
+                    is None
+                    else False
+                )
+                contents = audio_transcriber.transcribe_file(invocation_context)
+                logger.debug('Sending history to model: %s', contents)
+                await llm_connection.send_history(contents)
+                invocation_context.transcription_cache = None
+                trace_send_data(invocation_context, event_id, contents)
+              else:
+                await llm_connection.send_history(llm_request.contents)
+                trace_send_data(
+                    invocation_context, event_id, llm_request.contents
+                )
+
+          send_task = asyncio.create_task(
+              self._send_to_model(llm_connection, invocation_context)
+          )
+
+          try:
+            async for event in self._receive_from_model(
+                llm_connection,
+                event_id,
+                invocation_context,
+                llm_request,
+            ):
+              # Empty event means the queue is closed.
+              if not event:
+                break
+              logger.debug('Receive new event: %s', event)
+              yield event
+              # send back the function response
+              if event.get_function_responses():
+                logger.debug(
+                    'Sending back last function response event: %s', event
+                )
+                invocation_context.live_request_queue.send_content(
+                    event.content
+                )
+              if (
+                  event.content
+                  and event.content.parts
+                  and event.content.parts[0].function_response
+                  and event.content.parts[0].function_response.name
+                  == 'transfer_to_agent'
+              ):
+                await asyncio.sleep(1)
+                # cancel the tasks that belong to the closed connection.
+                send_task.cancel()
+                await llm_connection.close()
+              if (
+                  event.content
+                  and event.content.parts
+                  and event.content.parts[0].function_response
+                  and event.content.parts[0].function_response.name
+                  == 'task_completed'
+              ):
+                # this is used for the sequential agent to signal the end of the agent.
+                await asyncio.sleep(1)
+                # cancel the tasks that belong to the closed connection.
+                send_task.cancel()
+                return
+          finally:
+            # Clean up
+            if not send_task.done():
+              send_task.cancel()
+            try:
+              await send_task
+            except asyncio.CancelledError:
+              pass
+      except (ConnectionClosed, ConnectionClosedOK) as e:
+        # When the session times out, it will just close without raising an
+        # exception, so this except clause is for the bad cases.
+        logger.error(f'Connection closed: {e}.')
+        raise
+      except Exception as e:
+        logger.error(
+            f'An unexpected error occurred in live flow: {e}', exc_info=True
+        )
+        raise

   async def _send_to_model(
       self,
```

```diff
@@ -246,6 +282,14 @@ def get_author_for_event(llm_response):
     try:
       while True:
         async for llm_response in llm_connection.receive():
+          if llm_response.live_session_resumption_update:
+            logger.info(
+                'Update session resumption handle:'
+                f' {llm_response.live_session_resumption_update}.'
+            )
+            invocation_context.live_session_resumption_handle = (
+                llm_response.live_session_resumption_update.new_handle
+            )
           model_response_event = Event(
               id=Event.new_id(),
               invocation_id=invocation_context.invocation_id,
```
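
To see the new control flow without the diff noise, here is a stripped-down sketch of the cache-and-reconnect loop this hunk implements. `connect_fn` and the message shape are placeholders, not the ADK API; only the handle-caching pattern mirrors the code above:

```python
import logging
from typing import Optional

logger = logging.getLogger(__name__)


async def run_with_resumption(connect_fn) -> None:
  """Illustrative loop: cache the server-issued resumption handle and
  pass it back on the next connection attempt."""
  handle: Optional[str] = None  # latest cached resumption handle
  attempt = 1
  while True:
    try:
      if handle is not None:
        logger.info('Attempting to reconnect (Attempt %s)...', attempt)
        attempt += 1
      # connect_fn is a hypothetical async context manager factory that
      # accepts the cached handle (None on the first attempt).
      async with connect_fn(handle) as session:
        async for message in session.receive():
          if message.new_handle is not None:
            # The server periodically refreshes the handle; keep the
            # newest one so the next reconnect resumes this session.
            handle = message.new_handle
            continue
          ...  # process normal model output here
    except ConnectionError as e:
      # A timed-out session usually just closes without raising; this
      # surfaces the abnormal closes, as in the hunk above.
      logger.error('Connection closed: %s', e)
      raise
```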

src/google/adk/models/gemini_llm_connection.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -219,6 +219,13 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
             for function_call in message.tool_call.function_calls
         ]
         yield LlmResponse(content=types.Content(role='model', parts=parts))
+      if message.session_resumption_update:
+        logger.info('Received session resumption message: %s', message)
+        yield (
+            LlmResponse(
+                live_session_resumption_update=message.session_resumption_update
+            )
+        )

   async def close(self):
     """Closes the llm server connection."""
```

src/google/adk/models/google_llm.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -289,6 +289,7 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
         ],
     )
     llm_request.live_connect_config.tools = llm_request.config.tools
+    logger.info('Connecting to live with llm_request:%s', llm_request)
    async with self._live_api_client.aio.live.connect(
         model=llm_request.model, config=llm_request.live_connect_config
     ) as live_session:
```

src/google/adk/models/llm_response.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -89,6 +89,11 @@ class LlmResponse(BaseModel):
   usage_metadata: Optional[types.GenerateContentResponseUsageMetadata] = None
   """The usage metadata of the LlmResponse"""

+  live_session_resumption_update: Optional[
+      types.LiveServerSessionResumptionUpdate
+  ] = None
+  """The session resumption update of the LlmResponse"""
+
   @staticmethod
   def create(
       generate_content_response: types.GenerateContentResponse,
```
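
Downstream consumers of a connection's receive() stream can branch on this new field; a hedged sketch of that pattern (the function and variable names are illustrative; only `live_session_resumption_update` and `new_handle` come from this commit):

```python
# Sketch: consuming the new LlmResponse field yielded by receive().
async def drain_responses(llm_connection, invocation_context) -> None:
  async for llm_response in llm_connection.receive():
    if llm_response.live_session_resumption_update:
      # Cache the refreshed handle for the next reconnect, mirroring
      # what base_llm_flow.py does in the hunk above.
      invocation_context.live_session_resumption_handle = (
          llm_response.live_session_resumption_update.new_handle
      )
      continue
    ...  # handle content, tool calls, transcription, etc.
```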
