247 changes: 184 additions & 63 deletions src/crawlee/storage_clients/_sql/_client_mixin.py
@@ -2,11 +2,12 @@

from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone
from logging import getLogger
from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast, overload

from sqlalchemy import delete, select, text, update
from sqlalchemy import CursorResult, delete, select, text, update
from sqlalchemy import func as sql_func
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.dialects.sqlite import insert as lite_insert
from sqlalchemy.exc import SQLAlchemyError
@@ -25,10 +26,13 @@

from ._db_models import (
DatasetItemDb,
DatasetMetadataBufferDb,
DatasetMetadataDb,
KeyValueStoreMetadataBufferDb,
KeyValueStoreMetadataDb,
KeyValueStoreRecordDb,
RequestDb,
RequestQueueMetadataBufferDb,
RequestQueueMetadataDb,
)
from ._storage_client import SqlStorageClient
@@ -42,7 +46,6 @@ class MetadataUpdateParams(TypedDict, total=False):

update_accessed_at: NotRequired[bool]
update_modified_at: NotRequired[bool]
force: NotRequired[bool]


class SqlClientMixin(ABC):
@@ -57,21 +60,24 @@ class SqlClientMixin(ABC):
_METADATA_TABLE: ClassVar[type[DatasetMetadataDb | KeyValueStoreMetadataDb | RequestQueueMetadataDb]]
"""SQLAlchemy model for metadata."""

_BUFFER_TABLE: ClassVar[
type[KeyValueStoreMetadataBufferDb | DatasetMetadataBufferDb | RequestQueueMetadataBufferDb]
]
"""SQLAlchemy model for metadata buffer."""

_ITEM_TABLE: ClassVar[type[DatasetItemDb | KeyValueStoreRecordDb | RequestDb]]
"""SQLAlchemy model for items."""

_CLIENT_TYPE: ClassVar[str]
"""Human-readable client type for error messages."""

_BLOCK_BUFFER_TIME = timedelta(milliseconds=200)
"""Time interval that blocks buffer reading to update metadata."""

def __init__(self, *, id: str, storage_client: SqlStorageClient) -> None:
self._id = id
self._storage_client = storage_client

# Time tracking to reduce database writes during frequent operations
self._accessed_at_allow_update_after: datetime | None = None
self._modified_at_allow_update_after: datetime | None = None
self._accessed_modified_update_interval = storage_client.get_accessed_modified_update_interval()

@classmethod
async def _open(
cls,
@@ -109,7 +115,9 @@ async def _open(

if orm_metadata:
client = cls(id=orm_metadata.id, storage_client=storage_client)
await client._update_metadata(session, update_accessed_at=True)
await client._add_buffer_record(session)
# Ensure any pending buffer updates are processed
await client._process_buffers()
else:
now = datetime.now(timezone.utc)
metadata = metadata_model(
@@ -121,8 +129,6 @@
**extra_metadata_fields,
)
client = cls(id=metadata.id, storage_client=storage_client)
client._accessed_at_allow_update_after = now + client._accessed_modified_update_interval
client._modified_at_allow_update_after = now + client._accessed_modified_update_interval
session.add(cls._METADATA_TABLE(**metadata.model_dump(), internal_name=internal_name))

return client
@@ -262,6 +268,9 @@ async def _purge(self, metadata_kwargs: MetadataUpdateParams) -> None:
Args:
metadata_kwargs: Arguments to pass to _update_metadata.
"""
# Process buffers to ensure metadata is up to date before purging
await self._process_buffers()

stmt = delete(self._ITEM_TABLE).where(self._ITEM_TABLE.storage_id == self._id)
async with self.get_session(with_simple_commit=True) as session:
await session.execute(stmt)
Expand Down Expand Up @@ -290,53 +299,16 @@ async def _get_metadata(
self, metadata_model: type[DatasetMetadata | KeyValueStoreMetadata | RequestQueueMetadata]
) -> DatasetMetadata | KeyValueStoreMetadata | RequestQueueMetadata:
"""Retrieve client metadata."""
# Process any pending buffer updates first
await self._process_buffers()

async with self.get_session() as session:
orm_metadata = await session.get(self._METADATA_TABLE, self._id)
if not orm_metadata:
raise ValueError(f'{self._CLIENT_TYPE} with ID "{self._id}" not found.')

return metadata_model.model_validate(orm_metadata)

def _default_update_metadata(
self, *, update_accessed_at: bool = False, update_modified_at: bool = False, force: bool = False
) -> dict[str, Any]:
"""Prepare common metadata updates with rate limiting.

Args:
update_accessed_at: Whether to update accessed_at timestamp.
update_modified_at: Whether to update modified_at timestamp.
force: Whether to force the update regardless of rate limiting.
"""
values_to_set: dict[str, Any] = {}
now = datetime.now(timezone.utc)

# If the record must be updated (for example, when updating counters), we update timestamps and shift the time.
if force:
if update_modified_at:
values_to_set['modified_at'] = now
self._modified_at_allow_update_after = now + self._accessed_modified_update_interval
if update_accessed_at:
values_to_set['accessed_at'] = now
self._accessed_at_allow_update_after = now + self._accessed_modified_update_interval

elif update_modified_at and (
self._modified_at_allow_update_after is None or now >= self._modified_at_allow_update_after
):
values_to_set['modified_at'] = now
self._modified_at_allow_update_after = now + self._accessed_modified_update_interval
# The record will be updated, so we can also update `accessed_at` and shift the time.
if update_accessed_at:
values_to_set['accessed_at'] = now
self._accessed_at_allow_update_after = now + self._accessed_modified_update_interval

elif update_accessed_at and (
self._accessed_at_allow_update_after is None or now >= self._accessed_at_allow_update_after
):
values_to_set['accessed_at'] = now
self._accessed_at_allow_update_after = now + self._accessed_modified_update_interval

return values_to_set

@abstractmethod
def _specific_update_metadata(self, **kwargs: Any) -> dict[str, Any]:
"""Prepare storage-specific metadata updates.
@@ -347,39 +319,188 @@ def _specific_update_metadata(self, **kwargs: Any) -> dict[str, Any]:
**kwargs: Storage-specific update parameters.
"""

@abstractmethod
def _prepare_buffer_data(self, **kwargs: Any) -> dict[str, Any]:
"""Prepare storage-specific buffer data. Must be implemented by concrete classes."""

@abstractmethod
async def _apply_buffer_updates(self, session: AsyncSession, max_buffer_id: int) -> None:
"""Apply aggregated buffer updates to metadata. Must be implemented by concrete classes.

Args:
session: Active database session.
max_buffer_id: Maximum buffer record ID to process.
"""

async def _update_metadata(
self,
session: AsyncSession,
*,
update_accessed_at: bool = False,
update_modified_at: bool = False,
force: bool = False,
**kwargs: Any,
) -> bool:
"""Update storage metadata combining common and specific fields.
) -> None:
"""Directly update storage metadata combining common and specific fields.

Args:
session: Active database session.
update_accessed_at: Whether to update accessed_at timestamp.
update_modified_at: Whether to update modified_at timestamp.
force: Whether to force the timestamp updates regardless of rate limiting.
**kwargs: Additional arguments for _specific_update_metadata.

Returns:
True if any updates were made, False otherwise
"""
values_to_set = self._default_update_metadata(
update_accessed_at=update_accessed_at, update_modified_at=update_modified_at, force=force
)
values_to_set: dict[str, Any] = {}
now = datetime.now(timezone.utc)

if update_accessed_at:
values_to_set['accessed_at'] = now

if update_modified_at:
values_to_set['modified_at'] = now

values_to_set.update(self._specific_update_metadata(**kwargs))

if values_to_set:
if (stmt := values_to_set.pop('custom_stmt', None)) is None:
stmt = update(self._METADATA_TABLE).where(self._METADATA_TABLE.id == self._id)

stmt = stmt.values(**values_to_set)
await session.execute(stmt)
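
# For illustration, _specific_update_metadata can return a 'custom_stmt' key
# to replace the default UPDATE with a guarded one; the remaining keys are
# attached via .values(...) above. The `had_multiple_clients` flag is an
# assumption here:
#
#     def _specific_update_metadata(self, *, mark_multi_client: bool = False, **_: Any) -> dict[str, Any]:
#         if mark_multi_client:
#             # Flip the flag only while it is still unset, in one round trip.
#             stmt = update(self._METADATA_TABLE).where(
#                 self._METADATA_TABLE.id == self._id,
#                 self._METADATA_TABLE.had_multiple_clients.is_(False),
#             )
#             return {'custom_stmt': stmt, 'had_multiple_clients': True}
#         return {}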

async def _add_buffer_record(
self,
session: AsyncSession,
*,
update_modified_at: bool = False,
**kwargs: Any,
) -> None:
"""Add a record to the buffer table and update metadata.

Args:
session: Active database session.
update_modified_at: Whether to update modified_at timestamp.
**kwargs: Additional arguments for _prepare_buffer_data.
"""
now = datetime.now(timezone.utc)
values_to_set = {
'storage_id': self._id,
'accessed_at': now, # All entries in the buffer require updating `accessed_at`
'modified_at': now if update_modified_at else None,
}
values_to_set.update(self._prepare_buffer_data(**kwargs))

session.add(self._BUFFER_TABLE(**values_to_set))
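
# For illustration, a call like
#     await self._add_buffer_record(session, update_modified_at=True, delta_item_count=1)
# (with a hypothetical `delta_item_count` field supplied by
# _prepare_buffer_data) appends one cheap row:
#
#     storage_id | accessed_at | modified_at | delta_item_count
#     <self._id> | <now>       | <now>       | 1
#
# Rows accumulate without touching the metadata table and are folded in
# later by _process_buffers().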

async def _try_acquire_buffer_lock(self, session: AsyncSession) -> bool:
"""Try to acquire buffer processing lock for 200ms.

Args:
session: Active database session.

Returns:
True if lock was acquired, False if already locked by another process.
"""
now = datetime.now(timezone.utc)
lock_until = now + self._BLOCK_BUFFER_TIME
dialect = self._storage_client.get_dialect_name()

if dialect == 'postgresql':
select_stmt = (
select(self._METADATA_TABLE)
.where(
self._METADATA_TABLE.id == self._id,
(self._METADATA_TABLE.buffer_locked_until.is_(None))
| (self._METADATA_TABLE.buffer_locked_until < now),
select(self._BUFFER_TABLE.id).where(self._BUFFER_TABLE.storage_id == self._id).exists(),
)
.with_for_update(skip_locked=True)
)
result = await session.execute(select_stmt)
metadata_row = result.scalar_one_or_none()

if metadata_row is None:
# Either conditions not met OR row is locked by another process
return False

# Acquire lock only if not currently locked or lock has expired
update_stmt = (
update(self._METADATA_TABLE)
.where(
self._METADATA_TABLE.id == self._id,
(self._METADATA_TABLE.buffer_locked_until.is_(None)) | (self._METADATA_TABLE.buffer_locked_until < now),
select(self._BUFFER_TABLE.id).where(self._BUFFER_TABLE.storage_id == self._id).exists(),
)
.values(buffer_locked_until=lock_until)
)

result = await session.execute(update_stmt)
# The cast is typing-only; execute() returns a CursorResult for an UPDATE statement.
result = cast('CursorResult', result)

if result.rowcount > 0:
await session.flush()
return True

return False
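
# For illustration, the compare-and-set above is roughly this SQL (on
# PostgreSQL the row is additionally pre-claimed with SELECT ... FOR UPDATE
# SKIP LOCKED so contending workers bail out instead of blocking):
#
#     UPDATE <metadata_table>
#     SET buffer_locked_until = :now + 200 ms
#     WHERE id = :storage_id
#       AND (buffer_locked_until IS NULL OR buffer_locked_until < :now)
#       AND EXISTS (SELECT 1 FROM <buffer_table> WHERE storage_id = :storage_id);
#
# rowcount = 1 means this worker owns the lock; rowcount = 0 means another
# worker holds a live lock or there is nothing to process.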

async def _release_buffer_lock(self, session: AsyncSession) -> None:
"""Release buffer processing lock by setting buffer_locked_until to NULL.

Args:
session: Active database session.
"""
stmt = update(self._METADATA_TABLE).where(self._METADATA_TABLE.id == self._id).values(buffer_locked_until=None)

await session.execute(stmt)

await session.flush()

async def _has_pending_buffer_updates(self, session: AsyncSession) -> bool:
"""Check if there are pending buffer updates not yet applied to metadata.

Returns False only when buffer_locked_until is NULL (metadata is consistent).

Returns:
True if metadata might be inconsistent due to pending buffer updates.
"""
result = await session.execute(
select(self._METADATA_TABLE.buffer_locked_until).where(self._METADATA_TABLE.id == self._id)
)

locked_until = result.scalar()

# Any non-NULL value means there are pending updates
return locked_until is not None
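
# For illustration, a read path can combine this check with
# _process_buffers() to get consistent counters (helper name hypothetical):
#
#     async def _ensure_consistent_metadata(self) -> None:
#         async with self.get_session() as session:
#             pending = await self._has_pending_buffer_updates(session)
#         if pending:
#             await self._process_buffers()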

async def _process_buffers(self) -> None:
"""Process pending buffer updates and apply them to metadata."""
async with self.get_session(with_simple_commit=True) as session:
# Try to acquire buffer processing lock
if not await self._try_acquire_buffer_lock(session):
# Another process is currently processing buffers or lock acquisition failed
return

# Get the maximum buffer ID at this moment
# This creates a consistent snapshot - records added during processing won't be included
max_buffer_id_stmt = select(sql_func.max(self._BUFFER_TABLE.id)).where(
self._BUFFER_TABLE.storage_id == self._id
)

result = await session.execute(max_buffer_id_stmt)
max_buffer_id = result.scalar()

if max_buffer_id is None:
# No buffer records to process. Release the lock and exit.
await self._release_buffer_lock(session)
return

# Apply aggregated buffer updates to metadata using only records <= max_buffer_id
# This method is implemented by concrete storage classes
await self._apply_buffer_updates(session, max_buffer_id=max_buffer_id)

# Clean up only the processed buffer records (those <= max_buffer_id)
delete_stmt = delete(self._BUFFER_TABLE).where(
self._BUFFER_TABLE.storage_id == self._id, self._BUFFER_TABLE.id <= max_buffer_id
)

await session.execute(delete_stmt)

# Release the lock after successful processing
await self._release_buffer_lock(session)
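

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this diff): the buffered-metadata cycle
# from a caller's perspective. `client` stands for any concrete subclass of
# SqlClientMixin; `push_one`, the `data` item column, and the
# `delta_item_count` buffer field are assumptions for illustration.


async def push_one(client: SqlClientMixin, item: dict[str, Any]) -> None:
    # Writers stay cheap: one item INSERT plus one buffer INSERT, with no
    # contention on the shared metadata row.
    async with client.get_session(with_simple_commit=True) as session:
        session.add(client._ITEM_TABLE(storage_id=client._id, data=item))
        await client._add_buffer_record(session, update_modified_at=True, delta_item_count=1)


async def reconcile_and_read(client: SqlClientMixin) -> None:
    # Readers reconcile first: _process_buffers() folds all buffered rows into
    # the metadata row (at most one concurrent worker wins the 200 ms lock),
    # so counters read afterwards are consistent.
    await client._process_buffers()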