Skip to content

Commit 2f00626

Browse files
DeanChensjcopybara-github
authored andcommitted
chore: fix python format.
PiperOrigin-RevId: 759674648
1 parent d0f117e commit 2f00626

File tree

1 file changed

+38
-38
lines changed

1 file changed

+38
-38
lines changed

src/google/adk/sessions/database_session_service.py

Lines changed: 38 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,7 @@
6161

6262

6363
class DynamicJSON(TypeDecorator):
64-
"""A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON
65-
66-
serialization for other databases.
67-
"""
64+
"""A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON serialization for other databases."""
6865

6966
impl = Text # Default implementation is TEXT
7067

@@ -242,10 +239,7 @@ class DatabaseSessionService(BaseSessionService):
242239
"""A session service that uses a database for storage."""
243240

244241
def __init__(self, db_url: str):
245-
"""
246-
Args:
247-
db_url: The database URL to connect to.
248-
"""
242+
"""Initializes the database session service with a database URL."""
249243
# 1. Create DB engine for db connection
250244
# 2. Create all tables based on schema
251245
# 3. Initialize all properties
@@ -274,7 +268,7 @@ def __init__(self, db_url: str):
274268
self.inspector = inspect(self.db_engine)
275269

276270
# DB session factory method
277-
self.DatabaseSessionFactory: sessionmaker[DatabaseSessionFactory] = (
271+
self.database_session_factory: sessionmaker[DatabaseSessionFactory] = (
278272
sessionmaker(bind=self.db_engine)
279273
)
280274

@@ -297,11 +291,11 @@ async def create_session(
297291
# 4. Build the session object with generated id
298292
# 5. Return the session
299293

300-
with self.DatabaseSessionFactory() as sessionFactory:
294+
with self.database_session_factory() as session_factory:
301295

302296
# Fetch app and user states from storage
303-
storage_app_state = sessionFactory.get(StorageAppState, (app_name))
304-
storage_user_state = sessionFactory.get(
297+
storage_app_state = session_factory.get(StorageAppState, (app_name))
298+
storage_user_state = session_factory.get(
305299
StorageUserState, (app_name, user_id)
306300
)
307301

@@ -311,12 +305,12 @@ async def create_session(
311305
# Create state tables if not exist
312306
if not storage_app_state:
313307
storage_app_state = StorageAppState(app_name=app_name, state={})
314-
sessionFactory.add(storage_app_state)
308+
session_factory.add(storage_app_state)
315309
if not storage_user_state:
316310
storage_user_state = StorageUserState(
317311
app_name=app_name, user_id=user_id, state={}
318312
)
319-
sessionFactory.add(storage_user_state)
313+
session_factory.add(storage_user_state)
320314

321315
# Extract state deltas
322316
app_state_delta, user_state_delta, session_state = _extract_state_delta(
@@ -340,10 +334,10 @@ async def create_session(
340334
id=session_id,
341335
state=session_state,
342336
)
343-
sessionFactory.add(storage_session)
344-
sessionFactory.commit()
337+
session_factory.add(storage_session)
338+
session_factory.commit()
345339

346-
sessionFactory.refresh(storage_session)
340+
session_factory.refresh(storage_session)
347341

348342
# Merge states for response
349343
merged_state = _merge_state(app_state, user_state, session_state)
@@ -368,31 +362,37 @@ async def get_session(
368362
# 1. Get the storage session entry from session table
369363
# 2. Get all the events based on session id and filtering config
370364
# 3. Convert and return the session
371-
with self.DatabaseSessionFactory() as sessionFactory:
372-
storage_session = sessionFactory.get(
365+
with self.database_session_factory() as session_factory:
366+
storage_session = session_factory.get(
373367
StorageSession, (app_name, user_id, session_id)
374368
)
375369
if storage_session is None:
376370
return None
377-
371+
378372
if config and config.after_timestamp:
379-
after_dt = datetime.fromtimestamp(config.after_timestamp, tz=timezone.utc)
373+
after_dt = datetime.fromtimestamp(
374+
config.after_timestamp, tz=timezone.utc
375+
)
380376
timestamp_filter = StorageEvent.timestamp > after_dt
381377
else:
382378
timestamp_filter = True
383379

384380
storage_events = (
385-
sessionFactory.query(StorageEvent)
381+
session_factory.query(StorageEvent)
386382
.filter(StorageEvent.session_id == storage_session.id)
387-
.filter(timestamp_filter)
383+
.filter(timestamp_filter)
388384
.order_by(StorageEvent.timestamp.asc())
389-
.limit(config.num_recent_events if config and config.num_recent_events else None)
385+
.limit(
386+
config.num_recent_events
387+
if config and config.num_recent_events
388+
else None
389+
)
390390
.all()
391391
)
392392

393393
# Fetch states from storage
394-
storage_app_state = sessionFactory.get(StorageAppState, (app_name))
395-
storage_user_state = sessionFactory.get(
394+
storage_app_state = session_factory.get(StorageAppState, (app_name))
395+
storage_user_state = session_factory.get(
396396
StorageUserState, (app_name, user_id)
397397
)
398398

@@ -436,9 +436,9 @@ async def get_session(
436436
async def list_sessions(
437437
self, *, app_name: str, user_id: str
438438
) -> ListSessionsResponse:
439-
with self.DatabaseSessionFactory() as sessionFactory:
439+
with self.database_session_factory() as session_factory:
440440
results = (
441-
sessionFactory.query(StorageSession)
441+
session_factory.query(StorageSession)
442442
.filter(StorageSession.app_name == app_name)
443443
.filter(StorageSession.user_id == user_id)
444444
.all()
@@ -459,14 +459,14 @@ async def list_sessions(
459459
async def delete_session(
460460
self, app_name: str, user_id: str, session_id: str
461461
) -> None:
462-
with self.DatabaseSessionFactory() as sessionFactory:
462+
with self.database_session_factory() as session_factory:
463463
stmt = delete(StorageSession).where(
464464
StorageSession.app_name == app_name,
465465
StorageSession.user_id == user_id,
466466
StorageSession.id == session_id,
467467
)
468-
sessionFactory.execute(stmt)
469-
sessionFactory.commit()
468+
session_factory.execute(stmt)
469+
session_factory.commit()
470470

471471
@override
472472
async def append_event(self, session: Session, event: Event) -> Event:
@@ -478,8 +478,8 @@ async def append_event(self, session: Session, event: Event) -> Event:
478478
# 1. Check if timestamp is stale
479479
# 2. Update session attributes based on event config
480480
# 3. Store event to table
481-
with self.DatabaseSessionFactory() as sessionFactory:
482-
storage_session = sessionFactory.get(
481+
with self.database_session_factory() as session_factory:
482+
storage_session = session_factory.get(
483483
StorageSession, (session.app_name, session.user_id, session.id)
484484
)
485485

@@ -493,10 +493,10 @@ async def append_event(self, session: Session, event: Event) -> Event:
493493
)
494494

495495
# Fetch states from storage
496-
storage_app_state = sessionFactory.get(
496+
storage_app_state = session_factory.get(
497497
StorageAppState, (session.app_name)
498498
)
499-
storage_user_state = sessionFactory.get(
499+
storage_user_state = session_factory.get(
500500
StorageUserState, (session.app_name, session.user_id)
501501
)
502502

@@ -545,10 +545,10 @@ async def append_event(self, session: Session, event: Event) -> Event:
545545
if event.content:
546546
storage_event.content = _session_util.encode_content(event.content)
547547

548-
sessionFactory.add(storage_event)
548+
session_factory.add(storage_event)
549549

550-
sessionFactory.commit()
551-
sessionFactory.refresh(storage_session)
550+
session_factory.commit()
551+
session_factory.refresh(storage_session)
552552

553553
# Update timestamp with commit time
554554
session.last_update_time = storage_session.update_time.timestamp()

0 commit comments

Comments
 (0)