Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 62 additions & 17 deletions framework/py/flwr/supercore/corestate/sql_corestate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,20 @@


import secrets
import threading
from logging import DEBUG
from typing import cast

from sqlalchemy import MetaData, text
from sqlalchemy.exc import IntegrityError
from sqlalchemy.exc import IntegrityError, OperationalError

from flwr.common import now
from flwr.common.constant import (
FLWR_APP_TOKEN_LENGTH,
HEARTBEAT_DEFAULT_INTERVAL,
HEARTBEAT_PATIENCE,
)
from flwr.common.logger import log
from flwr.supercore.sql_mixin import SqlMixin
from flwr.supercore.state.schema.corestate_tables import create_corestate_metadata
from flwr.supercore.utils import int64_to_uint64, uint64_to_int64
Expand All @@ -38,9 +41,13 @@
class SqlCoreState(CoreState, SqlMixin):
"""SQLAlchemy-based CoreState implementation."""

_TOKEN_CLEANUP_INTERVAL_SECONDS = 0.5

def __init__(self, database_path: str, object_store: ObjectStore) -> None:
    """Set up the SQL-backed core state.

    Parameters
    ----------
    database_path : str
        Path of the database, forwarded unchanged to the parent initializer.
    object_store : ObjectStore
        Object store kept on the instance as `_object_store`.
    """
    super().__init__(database_path)
    self._object_store = object_store
    # Throttling state for `_cleanup_expired_tokens`: the timestamp of the last
    # recorded cleanup attempt (starts at 0.0) and the lock that guards the
    # check-and-update of that timestamp.
    self._last_cleanup_timestamp = 0.0
    self._cleanup_lock = threading.Lock()

@property
def object_store(self) -> ObjectStore:
Expand Down Expand Up @@ -123,22 +130,60 @@ def _cleanup_expired_tokens(self) -> None:
Subclasses can override `_on_tokens_expired` to add custom cleanup logic.
"""
current = now().timestamp()

with self.session() as session:
# Delete expired tokens and get their run_ids and active_until timestamps
query = """
DELETE FROM token_store
WHERE active_until < :current
RETURNING run_id, active_until;
"""
rows = session.execute(text(query), {"current": current}).mappings().all()
expired_records = [
(int64_to_uint64(row["run_id"]), row["active_until"]) for row in rows
]

# Hook for subclasses
if expired_records:
self._on_tokens_expired(expired_records)
if (
current - self._last_cleanup_timestamp
< self._TOKEN_CLEANUP_INTERVAL_SECONDS
):
return

with self._cleanup_lock:
current = now().timestamp()
if (
current - self._last_cleanup_timestamp
< self._TOKEN_CLEANUP_INTERVAL_SECONDS
):
return

# Best-effort throttling: record the timestamp before running the cleanup so
# that a failed attempt (e.g. lock contention) is not retried by every request.
self._last_cleanup_timestamp = current
try:
# Cheap read-only pre-check: skip the DELETE write path entirely when there
# are no expired tokens.
has_expired_rows = self.query(
"SELECT 1 FROM token_store WHERE active_until < :current LIMIT 1;",
{"current": current},
)
if not has_expired_rows:
return

with self.session() as session:
# Delete expired tokens and get their run_ids and active_until
# timestamps.
query = """
DELETE FROM token_store
WHERE active_until < :current
RETURNING run_id, active_until;
"""
rows = (
session.execute(text(query), {"current": current})
.mappings()
.all()
)
expired_records = [
(int64_to_uint64(row["run_id"]), row["active_until"])
for row in rows
]
# Hook for subclasses
if expired_records:
self._on_tokens_expired(expired_records)
except OperationalError:
# Skip cleanup and let the next cleanup tick retry.
log(
DEBUG,
"Skipping token cleanup due to SQLite lock contention",
)
return

def _on_tokens_expired(self, expired_records: list[tuple[int, float]]) -> None:
"""Handle cleanup of expired tokens.
Expand Down