Skip to content

Commit f9c09ef

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Support returning all sessions when user_id is none in the request
resolves #3154 PiperOrigin-RevId: 819417330
1 parent 141318f commit f9c09ef

File tree

6 files changed

+40
-147
lines changed

6 files changed

+40
-147
lines changed

src/google/adk/sessions/base_session_service.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -83,18 +83,9 @@ async def get_session(
8383

8484
@abc.abstractmethod
8585
async def list_sessions(
86-
self, *, app_name: str, user_id: Optional[str] = None
86+
self, *, app_name: str, user_id: str
8787
) -> ListSessionsResponse:
88-
"""Lists all the sessions for a user.
89-
90-
Args:
91-
app_name: The name of the app.
92-
user_id: The ID of the user. If not provided, lists all sessions for all
93-
users.
94-
95-
Returns:
96-
A ListSessionsResponse containing the sessions.
97-
"""
88+
"""Lists all the sessions."""
9889

9990
@abc.abstractmethod
10091
async def delete_session(

src/google/adk/sessions/database_session_service.py

Lines changed: 13 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -554,42 +554,30 @@ async def get_session(
554554

555555
@override
556556
async def list_sessions(
557-
self, *, app_name: str, user_id: Optional[str] = None
557+
self, *, app_name: str, user_id: str
558558
) -> ListSessionsResponse:
559559
with self.database_session_factory() as sql_session:
560-
query = sql_session.query(StorageSession).filter(
561-
StorageSession.app_name == app_name
560+
results = (
561+
sql_session.query(StorageSession)
562+
.filter(StorageSession.app_name == app_name)
563+
.filter(StorageSession.user_id == user_id)
564+
.all()
562565
)
563-
if user_id is not None:
564-
query = query.filter(StorageSession.user_id == user_id)
565-
results = query.all()
566566

567-
# Fetch app state from storage
567+
# Fetch states from storage
568568
storage_app_state = sql_session.get(StorageAppState, (app_name))
569-
app_state = storage_app_state.state if storage_app_state else {}
569+
storage_user_state = sql_session.get(
570+
StorageUserState, (app_name, user_id)
571+
)
570572

571-
# Fetch user state(s) from storage
572-
user_states_map = {}
573-
if user_id is not None:
574-
storage_user_state = sql_session.get(
575-
StorageUserState, (app_name, user_id)
576-
)
577-
if storage_user_state:
578-
user_states_map[user_id] = storage_user_state.state
579-
else:
580-
all_user_states_for_app = (
581-
sql_session.query(StorageUserState)
582-
.filter(StorageUserState.app_name == app_name)
583-
.all()
584-
)
585-
for storage_user_state in all_user_states_for_app:
586-
user_states_map[storage_user_state.user_id] = storage_user_state.state
573+
app_state = storage_app_state.state if storage_app_state else {}
574+
user_state = storage_user_state.state if storage_user_state else {}
587575

588576
sessions = []
589577
for storage_session in results:
590578
session_state = storage_session.state
591-
user_state = user_states_map.get(storage_session.user_id, {})
592579
merged_state = _merge_state(app_state, user_state, session_state)
580+
593581
sessions.append(storage_session.to_session(state=merged_state))
594582
return ListSessionsResponse(sessions=sessions)
595583

src/google/adk/sessions/in_memory_session_service.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -201,41 +201,31 @@ def _merge_state(
201201

202202
@override
203203
async def list_sessions(
204-
self, *, app_name: str, user_id: Optional[str] = None
204+
self, *, app_name: str, user_id: str
205205
) -> ListSessionsResponse:
206206
return self._list_sessions_impl(app_name=app_name, user_id=user_id)
207207

208208
def list_sessions_sync(
209-
self, *, app_name: str, user_id: Optional[str] = None
209+
self, *, app_name: str, user_id: str
210210
) -> ListSessionsResponse:
211211
logger.warning('Deprecated. Please migrate to the async method.')
212212
return self._list_sessions_impl(app_name=app_name, user_id=user_id)
213213

214214
def _list_sessions_impl(
215-
self, *, app_name: str, user_id: Optional[str] = None
215+
self, *, app_name: str, user_id: str
216216
) -> ListSessionsResponse:
217217
empty_response = ListSessionsResponse()
218218
if app_name not in self.sessions:
219219
return empty_response
220-
if user_id is not None and user_id not in self.sessions[app_name]:
220+
if user_id not in self.sessions[app_name]:
221221
return empty_response
222222

223223
sessions_without_events = []
224-
225-
if user_id is None:
226-
for user_id in self.sessions[app_name]:
227-
for session_id in self.sessions[app_name][user_id]:
228-
session = self.sessions[app_name][user_id][session_id]
229-
copied_session = copy.deepcopy(session)
230-
copied_session.events = []
231-
copied_session = self._merge_state(app_name, user_id, copied_session)
232-
sessions_without_events.append(copied_session)
233-
else:
234-
for session in self.sessions[app_name][user_id].values():
235-
copied_session = copy.deepcopy(session)
236-
copied_session.events = []
237-
copied_session = self._merge_state(app_name, user_id, copied_session)
238-
sessions_without_events.append(copied_session)
224+
for session in self.sessions[app_name][user_id].values():
225+
copied_session = copy.deepcopy(session)
226+
copied_session.events = []
227+
copied_session = self._merge_state(app_name, user_id, copied_session)
228+
sessions_without_events.append(copied_session)
239229
return ListSessionsResponse(sessions=sessions_without_events)
240230

241231
@override

src/google/adk/sessions/vertex_ai_session_service.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -200,25 +200,22 @@ async def get_session(
200200

201201
@override
202202
async def list_sessions(
203-
self, *, app_name: str, user_id: Optional[str] = None
203+
self, *, app_name: str, user_id: str
204204
) -> ListSessionsResponse:
205205
reasoning_engine_id = self._get_reasoning_engine_id(app_name)
206206
api_client = self._get_api_client()
207207

208208
sessions = []
209-
config = {}
210-
if user_id is not None:
211-
config['filter'] = f'user_id="{user_id}"'
212209
sessions_iterator = api_client.agent_engines.sessions.list(
213210
name=f'reasoningEngines/{reasoning_engine_id}',
214-
config=config,
211+
config={'filter': f'user_id="{user_id}"'},
215212
)
216213

217214
for api_session in sessions_iterator:
218215
sessions.append(
219216
Session(
220217
app_name=app_name,
221-
user_id=api_session.user_id,
218+
user_id=user_id,
222219
id=api_session.name.split('/')[-1],
223220
state=getattr(api_session, 'session_state', None) or {},
224221
last_update_time=api_session.update_time.timestamp(),

tests/unittests/sessions/test_session_service.py

Lines changed: 3 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -116,70 +116,9 @@ async def test_create_and_list_sessions(service_type):
116116
app_name=app_name, user_id=user_id
117117
)
118118
sessions = list_sessions_response.sessions
119-
assert len(sessions) == len(session_ids)
120-
assert {s.id for s in sessions} == set(session_ids)
121-
for session in sessions:
122-
assert session.state == {'key': 'value' + session.id}
123-
124-
125-
@pytest.mark.asyncio
126-
@pytest.mark.parametrize(
127-
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
128-
)
129-
async def test_list_sessions_all_users(service_type):
130-
session_service = get_session_service(service_type)
131-
app_name = 'my_app'
132-
user_id_1 = 'user1'
133-
user_id_2 = 'user2'
134-
135-
await session_service.create_session(
136-
app_name=app_name,
137-
user_id=user_id_1,
138-
session_id='session1a',
139-
state={'key': 'value1a'},
140-
)
141-
await session_service.create_session(
142-
app_name=app_name,
143-
user_id=user_id_1,
144-
session_id='session1b',
145-
state={'key': 'value1b'},
146-
)
147-
await session_service.create_session(
148-
app_name=app_name,
149-
user_id=user_id_2,
150-
session_id='session2a',
151-
state={'key': 'value2a'},
152-
)
153-
154-
# List sessions for user1
155-
list_sessions_response_1 = await session_service.list_sessions(
156-
app_name=app_name, user_id=user_id_1
157-
)
158-
sessions_1 = list_sessions_response_1.sessions
159-
assert len(sessions_1) == 2
160-
assert {s.id for s in sessions_1} == {'session1a', 'session1b'}
161-
for session in sessions_1:
162-
if session.id == 'session1a':
163-
assert session.state == {'key': 'value1a'}
164-
else:
165-
assert session.state == {'key': 'value1b'}
166-
167-
# List sessions for user2
168-
list_sessions_response_2 = await session_service.list_sessions(
169-
app_name=app_name, user_id=user_id_2
170-
)
171-
sessions_2 = list_sessions_response_2.sessions
172-
assert len(sessions_2) == 1
173-
assert sessions_2[0].id == 'session2a'
174-
assert sessions_2[0].state == {'key': 'value2a'}
175-
176-
# List sessions for all users
177-
list_sessions_response_all = await session_service.list_sessions(
178-
app_name=app_name, user_id=None
179-
)
180-
sessions_all = list_sessions_response_all.sessions
181-
assert len(sessions_all) == 3
182-
assert {s.id for s in sessions_all} == {'session1a', 'session1b', 'session2a'}
119+
for i in range(len(sessions)):
120+
assert sessions[i].id == session_ids[i]
121+
assert sessions[i].state == {'key': 'value' + session_ids[i]}
183122

184123

185124
@pytest.mark.asyncio

tests/unittests/sessions/test_vertex_ai_session_service.py

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -252,22 +252,19 @@ def _get_session(self, name: str):
252252
def _list_sessions(self, name: str, config: dict[str, Any]):
253253
filter_val = config.get('filter', '')
254254
user_id_match = re.search(r'user_id="([^"]+)"', filter_val)
255-
if user_id_match:
256-
user_id = user_id_match.group(1)
257-
if user_id == 'user_with_pages':
258-
return [
259-
_convert_to_object(MOCK_SESSION_JSON_PAGE1),
260-
_convert_to_object(MOCK_SESSION_JSON_PAGE2),
261-
]
255+
if not user_id_match:
256+
raise ValueError(f'Could not find user_id in filter: {filter_val}')
257+
user_id = user_id_match.group(1)
258+
259+
if user_id == 'user_with_pages':
262260
return [
263-
_convert_to_object(session)
264-
for session in self.session_dict.values()
265-
if session['user_id'] == user_id
261+
_convert_to_object(MOCK_SESSION_JSON_PAGE1),
262+
_convert_to_object(MOCK_SESSION_JSON_PAGE2),
266263
]
267-
268-
# No user filter, return all sessions
269264
return [
270-
_convert_to_object(session) for session in self.session_dict.values()
265+
_convert_to_object(session)
266+
for session in self.session_dict.values()
267+
if session['user_id'] == user_id
271268
]
272269

273270
def _delete_session(self, name: str):
@@ -478,15 +475,6 @@ async def test_list_sessions_with_pagination():
478475
assert sessions.sessions[1].id == 'page2'
479476

480477

481-
@pytest.mark.asyncio
482-
@pytest.mark.usefixtures('mock_get_api_client')
483-
async def test_list_sessions_all_users():
484-
session_service = mock_vertex_ai_session_service()
485-
sessions = await session_service.list_sessions(app_name='123', user_id=None)
486-
assert len(sessions.sessions) == 5
487-
assert {s.id for s in sessions.sessions} == {'1', '2', '3', 'page1', 'page2'}
488-
489-
490478
@pytest.mark.asyncio
491479
@pytest.mark.usefixtures('mock_get_api_client')
492480
async def test_create_session():

0 commit comments

Comments
 (0)