Skip to content

Commit 314d6a4

Browse files
DeanChensjcopybara-github
authored andcommitted
fix: Return session state in list_session API endpoint
Resolves #2193 Resolves #781 PiperOrigin-RevId: 789143973
1 parent 247fd20 commit 314d6a4

File tree

4 files changed

+29
-8
lines changed

4 files changed

+29
-8
lines changed

src/google/adk/sessions/database_session_service.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -515,9 +515,22 @@ async def list_sessions(
515515
.filter(StorageSession.user_id == user_id)
516516
.all()
517517
)
518+
519+
# Fetch states from storage
520+
storage_app_state = sql_session.get(StorageAppState, (app_name))
521+
storage_user_state = sql_session.get(
522+
StorageUserState, (app_name, user_id)
523+
)
524+
525+
app_state = storage_app_state.state if storage_app_state else {}
526+
user_state = storage_user_state.state if storage_user_state else {}
527+
518528
sessions = []
519529
for storage_session in results:
520-
sessions.append(storage_session.to_session())
530+
session_state = storage_session.state
531+
merged_state = _merge_state(app_state, user_state, session_state)
532+
533+
sessions.append(storage_session.to_session(state=merged_state))
521534
return ListSessionsResponse(sessions=sessions)
522535

523536
@override

src/google/adk/sessions/in_memory_session_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ def _list_sessions_impl(
224224
for session in self.sessions[app_name][user_id].values():
225225
copied_session = copy.deepcopy(session)
226226
copied_session.events = []
227-
copied_session.state = {}
227+
copied_session = self._merge_state(app_name, user_id, copied_session)
228228
sessions_without_events.append(copied_session)
229229
return ListSessionsResponse(sessions=sessions_without_events)
230230

src/google/adk/sessions/vertex_ai_session_service.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -280,24 +280,28 @@ async def list_sessions(
280280
parsed_user_id = urllib.parse.quote(f'''"{user_id}"''', safe='')
281281
path = path + f'?filter=user_id={parsed_user_id}'
282282

283-
api_response = await api_client.async_request(
283+
list_sessions_api_response = await api_client.async_request(
284284
http_method='GET',
285285
path=path,
286286
request_dict={},
287287
)
288-
api_response = _convert_api_response(api_response)
288+
list_sessions_api_response = _convert_api_response(
289+
list_sessions_api_response
290+
)
289291

290292
# Handles empty response case
291-
if not api_response or api_response.get('httpHeaders', None):
293+
if not list_sessions_api_response or list_sessions_api_response.get(
294+
'httpHeaders', None
295+
):
292296
return ListSessionsResponse()
293297

294298
sessions = []
295-
for api_session in api_response['sessions']:
299+
for api_session in list_sessions_api_response['sessions']:
296300
session = Session(
297301
app_name=app_name,
298302
user_id=user_id,
299303
id=api_session['name'].split('/')[-1],
300-
state={},
304+
state=api_session.get('sessionState', {}),
301305
last_update_time=isoparse(api_session['updateTime']).timestamp(),
302306
)
303307
sessions.append(session)

tests/unittests/sessions/test_session_service.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,10 @@ async def test_create_and_list_sessions(service_type):
106106
session_ids = ['session' + str(i) for i in range(5)]
107107
for session_id in session_ids:
108108
await session_service.create_session(
109-
app_name=app_name, user_id=user_id, session_id=session_id
109+
app_name=app_name,
110+
user_id=user_id,
111+
session_id=session_id,
112+
state={'key': 'value' + session_id},
110113
)
111114

112115
list_sessions_response = await session_service.list_sessions(
@@ -115,6 +118,7 @@ async def test_create_and_list_sessions(service_type):
115118
sessions = list_sessions_response.sessions
116119
for i in range(len(sessions)):
117120
assert sessions[i].id == session_ids[i]
121+
assert sessions[i].state == {'key': 'value' + session_ids[i]}
118122

119123

120124
@pytest.mark.asyncio

0 commit comments

Comments
 (0)