Skip to content

Commit 4ffbae9

Browse files
authored
feat: Google ADK store signature update (#153)
Google ADK is moving quickly. This change aligns the signature with the latest version
1 parent 7e8c255 commit 4ffbae9

File tree

13 files changed

+754
-606
lines changed

13 files changed

+754
-606
lines changed

sqlspec/adapters/adbc/adk/store.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -639,31 +639,41 @@ def delete_session(self, session_id: str) -> None:
639639
finally:
640640
cursor.close() # type: ignore[no-untyped-call]
641641

642-
def list_sessions(self, app_name: str, user_id: str) -> "list[SessionRecord]":
643-
"""List all sessions for a user in an app.
642+
def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
643+
"""List sessions for an app, optionally filtered by user.
644644
645645
Args:
646646
app_name: Application name.
647-
user_id: User identifier.
647+
user_id: User identifier. If None, lists all sessions for the app.
648648
649649
Returns:
650650
List of session records ordered by update_time DESC.
651651
652652
Notes:
653-
Uses composite index on (app_name, user_id).
654-
"""
655-
sql = f"""
656-
SELECT id, app_name, user_id, state, create_time, update_time
657-
FROM {self._session_table}
658-
WHERE app_name = ? AND user_id = ?
659-
ORDER BY update_time DESC
653+
Uses composite index on (app_name, user_id) when user_id is provided.
660654
"""
655+
if user_id is None:
656+
sql = f"""
657+
SELECT id, app_name, user_id, state, create_time, update_time
658+
FROM {self._session_table}
659+
WHERE app_name = ?
660+
ORDER BY update_time DESC
661+
"""
662+
params: tuple[str, ...] = (app_name,)
663+
else:
664+
sql = f"""
665+
SELECT id, app_name, user_id, state, create_time, update_time
666+
FROM {self._session_table}
667+
WHERE app_name = ? AND user_id = ?
668+
ORDER BY update_time DESC
669+
"""
670+
params = (app_name, user_id)
661671

662672
try:
663673
with self._config.provide_connection() as conn:
664674
cursor = conn.cursor()
665675
try:
666-
cursor.execute(sql, (app_name, user_id))
676+
cursor.execute(sql, params)
667677
rows = cursor.fetchall()
668678

669679
return [

sqlspec/adapters/aiosqlite/adk/store.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -342,29 +342,39 @@ async def update_session_state(self, session_id: str, state: "dict[str, Any]") -
342342
await conn.execute(sql, (state_json, now_julian, session_id))
343343
await conn.commit()
344344

345-
async def list_sessions(self, app_name: str, user_id: str) -> "list[SessionRecord]":
346-
"""List all sessions for a user in an app.
345+
async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
346+
"""List sessions for an app, optionally filtered by user.
347347
348348
Args:
349349
app_name: Application name.
350-
user_id: User identifier.
350+
user_id: User identifier. If None, lists all sessions for the app.
351351
352352
Returns:
353353
List of session records ordered by update_time DESC.
354354
355355
Notes:
356-
Uses composite index on (app_name, user_id).
357-
"""
358-
sql = f"""
359-
SELECT id, app_name, user_id, state, create_time, update_time
360-
FROM {self._session_table}
361-
WHERE app_name = ? AND user_id = ?
362-
ORDER BY update_time DESC
356+
Uses composite index on (app_name, user_id) when user_id is provided.
363357
"""
358+
if user_id is None:
359+
sql = f"""
360+
SELECT id, app_name, user_id, state, create_time, update_time
361+
FROM {self._session_table}
362+
WHERE app_name = ?
363+
ORDER BY update_time DESC
364+
"""
365+
params: tuple[str, ...] = (app_name,)
366+
else:
367+
sql = f"""
368+
SELECT id, app_name, user_id, state, create_time, update_time
369+
FROM {self._session_table}
370+
WHERE app_name = ? AND user_id = ?
371+
ORDER BY update_time DESC
372+
"""
373+
params = (app_name, user_id)
364374

365375
async with self._config.provide_connection() as conn:
366376
await self._enable_foreign_keys(conn)
367-
cursor = await conn.execute(sql, (app_name, user_id))
377+
cursor = await conn.execute(sql, params)
368378
rows = await cursor.fetchall()
369379

370380
return [

sqlspec/adapters/asyncmy/adk/store.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -326,29 +326,39 @@ async def delete_session(self, session_id: str) -> None:
326326
await cursor.execute(sql, (session_id,))
327327
await conn.commit()
328328

329-
async def list_sessions(self, app_name: str, user_id: str) -> "list[SessionRecord]":
330-
"""List all sessions for a user in an app.
329+
async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
330+
"""List sessions for an app, optionally filtered by user.
331331
332332
Args:
333333
app_name: Application name.
334-
user_id: User identifier.
334+
user_id: User identifier. If None, lists all sessions for the app.
335335
336336
Returns:
337337
List of session records ordered by update_time DESC.
338338
339339
Notes:
340-
Uses composite index on (app_name, user_id).
341-
"""
342-
sql = f"""
343-
SELECT id, app_name, user_id, state, create_time, update_time
344-
FROM {self._session_table}
345-
WHERE app_name = %s AND user_id = %s
346-
ORDER BY update_time DESC
340+
Uses composite index on (app_name, user_id) when user_id is provided.
347341
"""
342+
if user_id is None:
343+
sql = f"""
344+
SELECT id, app_name, user_id, state, create_time, update_time
345+
FROM {self._session_table}
346+
WHERE app_name = %s
347+
ORDER BY update_time DESC
348+
"""
349+
params: tuple[str, ...] = (app_name,)
350+
else:
351+
sql = f"""
352+
SELECT id, app_name, user_id, state, create_time, update_time
353+
FROM {self._session_table}
354+
WHERE app_name = %s AND user_id = %s
355+
ORDER BY update_time DESC
356+
"""
357+
params = (app_name, user_id)
348358

349359
try:
350360
async with self._config.provide_connection() as conn, conn.cursor() as cursor:
351-
await cursor.execute(sql, (app_name, user_id))
361+
await cursor.execute(sql, params)
352362
rows = await cursor.fetchall()
353363

354364
return [

sqlspec/adapters/asyncpg/adk/store.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -294,29 +294,39 @@ async def delete_session(self, session_id: str) -> None:
294294
async with self.config.provide_connection() as conn:
295295
await conn.execute(sql, session_id)
296296

297-
async def list_sessions(self, app_name: str, user_id: str) -> "list[SessionRecord]":
298-
"""List all sessions for a user in an app.
297+
async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
298+
"""List sessions for an app, optionally filtered by user.
299299
300300
Args:
301301
app_name: Application name.
302-
user_id: User identifier.
302+
user_id: User identifier. If None, lists all sessions for the app.
303303
304304
Returns:
305305
List of session records ordered by update_time DESC.
306306
307307
Notes:
308-
Uses composite index on (app_name, user_id).
309-
"""
310-
sql = f"""
311-
SELECT id, app_name, user_id, state, create_time, update_time
312-
FROM {self._session_table}
313-
WHERE app_name = $1 AND user_id = $2
314-
ORDER BY update_time DESC
308+
Uses composite index on (app_name, user_id) when user_id is provided.
315309
"""
310+
if user_id is None:
311+
sql = f"""
312+
SELECT id, app_name, user_id, state, create_time, update_time
313+
FROM {self._session_table}
314+
WHERE app_name = $1
315+
ORDER BY update_time DESC
316+
"""
317+
params = [app_name]
318+
else:
319+
sql = f"""
320+
SELECT id, app_name, user_id, state, create_time, update_time
321+
FROM {self._session_table}
322+
WHERE app_name = $1 AND user_id = $2
323+
ORDER BY update_time DESC
324+
"""
325+
params = [app_name, user_id]
316326

317327
try:
318328
async with self.config.provide_connection() as conn:
319-
rows = await conn.fetch(sql, app_name, user_id)
329+
rows = await conn.fetch(sql, *params)
320330

321331
return [
322332
SessionRecord(

sqlspec/adapters/bigquery/adk/store.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -351,20 +351,29 @@ async def update_session_state(self, session_id: str, state: "dict[str, Any]") -
351351
"""
352352
await async_(self._update_session_state)(session_id, state)
353353

354-
def _list_sessions(self, app_name: str, user_id: str) -> "list[SessionRecord]":
354+
def _list_sessions(self, app_name: str, user_id: "str | None") -> "list[SessionRecord]":
355355
"""Synchronous implementation of list_sessions."""
356356
table_name = self._get_full_table_name(self._session_table)
357-
sql = f"""
358-
SELECT id, app_name, user_id, JSON_VALUE(state) as state, create_time, update_time
359-
FROM {table_name}
360-
WHERE app_name = @app_name AND user_id = @user_id
361-
ORDER BY update_time DESC
362-
"""
363357

364-
params = [
365-
ScalarQueryParameter("app_name", "STRING", app_name),
366-
ScalarQueryParameter("user_id", "STRING", user_id),
367-
]
358+
if user_id is None:
359+
sql = f"""
360+
SELECT id, app_name, user_id, JSON_VALUE(state) as state, create_time, update_time
361+
FROM {table_name}
362+
WHERE app_name = @app_name
363+
ORDER BY update_time DESC
364+
"""
365+
params = [ScalarQueryParameter("app_name", "STRING", app_name)]
366+
else:
367+
sql = f"""
368+
SELECT id, app_name, user_id, JSON_VALUE(state) as state, create_time, update_time
369+
FROM {table_name}
370+
WHERE app_name = @app_name AND user_id = @user_id
371+
ORDER BY update_time DESC
372+
"""
373+
params = [
374+
ScalarQueryParameter("app_name", "STRING", app_name),
375+
ScalarQueryParameter("user_id", "STRING", user_id),
376+
]
368377

369378
with self._config.provide_connection() as conn:
370379
job_config = QueryJobConfig(query_parameters=params)
@@ -383,18 +392,18 @@ def _list_sessions(self, app_name: str, user_id: str) -> "list[SessionRecord]":
383392
for row in results
384393
]
385394

386-
async def list_sessions(self, app_name: str, user_id: str) -> "list[SessionRecord]":
387-
"""List all sessions for a user in an app.
395+
async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
396+
"""List sessions for an app, optionally filtered by user.
388397
389398
Args:
390399
app_name: Application name.
391-
user_id: User identifier.
400+
user_id: User identifier. If None, lists all sessions for the app.
392401
393402
Returns:
394403
List of session records ordered by update_time DESC.
395404
396405
Notes:
397-
Uses clustering on (app_name, user_id) for efficiency.
406+
Uses clustering on (app_name, user_id) when user_id is provided for efficiency.
398407
"""
399408
return await async_(self._list_sessions)(app_name, user_id)
400409

sqlspec/adapters/duckdb/adk/store.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -315,29 +315,39 @@ def delete_session(self, session_id: str) -> None:
315315
conn.execute(delete_session_sql, (session_id,))
316316
conn.commit()
317317

318-
def list_sessions(self, app_name: str, user_id: str) -> "list[SessionRecord]":
319-
"""List all sessions for a user in an app.
318+
def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]":
319+
"""List sessions for an app, optionally filtered by user.
320320
321321
Args:
322322
app_name: Application name.
323-
user_id: User identifier.
323+
user_id: User identifier. If None, lists all sessions for the app.
324324
325325
Returns:
326326
List of session records ordered by update_time DESC.
327327
328328
Notes:
329-
Uses composite index on (app_name, user_id).
330-
"""
331-
sql = f"""
332-
SELECT id, app_name, user_id, state, create_time, update_time
333-
FROM {self._session_table}
334-
WHERE app_name = ? AND user_id = ?
335-
ORDER BY update_time DESC
329+
Uses composite index on (app_name, user_id) when user_id is provided.
336330
"""
331+
if user_id is None:
332+
sql = f"""
333+
SELECT id, app_name, user_id, state, create_time, update_time
334+
FROM {self._session_table}
335+
WHERE app_name = ?
336+
ORDER BY update_time DESC
337+
"""
338+
params: tuple[str, ...] = (app_name,)
339+
else:
340+
sql = f"""
341+
SELECT id, app_name, user_id, state, create_time, update_time
342+
FROM {self._session_table}
343+
WHERE app_name = ? AND user_id = ?
344+
ORDER BY update_time DESC
345+
"""
346+
params = (app_name, user_id)
337347

338348
try:
339349
with self._config.provide_connection() as conn:
340-
cursor = conn.execute(sql, (app_name, user_id))
350+
cursor = conn.execute(sql, params)
341351
rows = cursor.fetchall()
342352

343353
return [

0 commit comments

Comments
 (0)