diff --git a/framework/py/flwr/supercore/corestate/sql_corestate.py b/framework/py/flwr/supercore/corestate/sql_corestate.py index 59d5ad706004..6d6192f1f409 100644 --- a/framework/py/flwr/supercore/corestate/sql_corestate.py +++ b/framework/py/flwr/supercore/corestate/sql_corestate.py @@ -16,10 +16,12 @@ import secrets +import threading +from logging import DEBUG from typing import cast from sqlalchemy import MetaData, text -from sqlalchemy.exc import IntegrityError +from sqlalchemy.exc import IntegrityError, OperationalError from flwr.common import now from flwr.common.constant import ( @@ -27,6 +29,7 @@ HEARTBEAT_DEFAULT_INTERVAL, HEARTBEAT_PATIENCE, ) +from flwr.common.logger import log from flwr.supercore.sql_mixin import SqlMixin from flwr.supercore.state.schema.corestate_tables import create_corestate_metadata from flwr.supercore.utils import int64_to_uint64, uint64_to_int64 @@ -38,9 +41,13 @@ class SqlCoreState(CoreState, SqlMixin): """SQLAlchemy-based CoreState implementation.""" + _TOKEN_CLEANUP_INTERVAL_SECONDS = 0.5 + def __init__(self, database_path: str, object_store: ObjectStore) -> None: super().__init__(database_path) self._object_store = object_store + self._cleanup_lock = threading.Lock() + self._last_cleanup_timestamp = 0.0 @property def object_store(self) -> ObjectStore: @@ -123,22 +130,60 @@ def _cleanup_expired_tokens(self) -> None: Subclasses can override `_on_tokens_expired` to add custom cleanup logic. """ current = now().timestamp() - - with self.session() as session: - # Delete expired tokens and get their run_ids and active_until timestamps - query = """ - DELETE FROM token_store - WHERE active_until < :current - RETURNING run_id, active_until; - """ - rows = session.execute(text(query), {"current": current}).mappings().all() - expired_records = [ - (int64_to_uint64(row["run_id"]), row["active_until"]) for row in rows - ] - - # Hook for subclasses - if expired_records: - self._on_tokens_expired(expired_records) + if ( + current - self._last_cleanup_timestamp + < self._TOKEN_CLEANUP_INTERVAL_SECONDS + ): + return + + with self._cleanup_lock: + current = now().timestamp() + if ( + current - self._last_cleanup_timestamp + < self._TOKEN_CLEANUP_INTERVAL_SECONDS + ): + return + + # Best-effort cleanup throttling: set timestamp before cleanup so lock + # contention does not cause every request to retry this write path. + self._last_cleanup_timestamp = current + try: + # Super cheap read-first check to avoid entering DELETE write path when + # there are no expired tokens. + has_expired_rows = self.query( + "SELECT 1 FROM token_store WHERE active_until < :current LIMIT 1;", + {"current": current}, + ) + if not has_expired_rows: + return + + with self.session() as session: + # Delete expired tokens and get their run_ids and active_until + # timestamps. + query = """ + DELETE FROM token_store + WHERE active_until < :current + RETURNING run_id, active_until; + """ + rows = ( + session.execute(text(query), {"current": current}) + .mappings() + .all() + ) + expired_records = [ + (int64_to_uint64(row["run_id"]), row["active_until"]) + for row in rows + ] + # Hook for subclasses + if expired_records: + self._on_tokens_expired(expired_records) + except OperationalError: + # Skip cleanup and let the next cleanup tick retry. + log( + DEBUG, + "Skipping token cleanup due to SQLite lock contention", + ) + return def _on_tokens_expired(self, expired_records: list[tuple[int, float]]) -> None: """Handle cleanup of expired tokens.