247 changes: 184 additions & 63 deletions src/crawlee/storage_clients/_sql/_client_mixin.py
@@ -2,11 +2,12 @@

from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone
from logging import getLogger
from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast, overload

from sqlalchemy import delete, select, text, update
from sqlalchemy import CursorResult, delete, select, text, update
from sqlalchemy import func as sql_func
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.dialects.sqlite import insert as lite_insert
from sqlalchemy.exc import SQLAlchemyError
@@ -25,10 +26,13 @@

from ._db_models import (
DatasetItemDb,
DatasetMetadataBufferDb,
DatasetMetadataDb,
KeyValueStoreMetadataBufferDb,
KeyValueStoreMetadataDb,
KeyValueStoreRecordDb,
RequestDb,
RequestQueueMetadataBufferDb,
RequestQueueMetadataDb,
)
from ._storage_client import SqlStorageClient
@@ -42,7 +46,6 @@ class MetadataUpdateParams(TypedDict, total=False):

update_accessed_at: NotRequired[bool]
update_modified_at: NotRequired[bool]
force: NotRequired[bool]


class SqlClientMixin(ABC):
@@ -57,21 +60,24 @@ class SqlClientMixin(ABC):
_METADATA_TABLE: ClassVar[type[DatasetMetadataDb | KeyValueStoreMetadataDb | RequestQueueMetadataDb]]
"""SQLAlchemy model for metadata."""

_BUFFER_TABLE: ClassVar[
type[KeyValueStoreMetadataBufferDb | DatasetMetadataBufferDb | RequestQueueMetadataBufferDb]
]
"""SQLAlchemy model for metadata buffer."""

_ITEM_TABLE: ClassVar[type[DatasetItemDb | KeyValueStoreRecordDb | RequestDb]]
"""SQLAlchemy model for items."""

_CLIENT_TYPE: ClassVar[str]
"""Human-readable client type for error messages."""

_BLOCK_BUFFER_TIME = timedelta(milliseconds=200)
"""How long the buffer lock is held while aggregated metadata updates are applied."""

def __init__(self, *, id: str, storage_client: SqlStorageClient) -> None:
self._id = id
self._storage_client = storage_client

# Time tracking to reduce database writes during frequent operation
self._accessed_at_allow_update_after: datetime | None = None
self._modified_at_allow_update_after: datetime | None = None
self._accessed_modified_update_interval = storage_client.get_accessed_modified_update_interval()

@classmethod
async def _open(
cls,
@@ -109,7 +115,9 @@ async def _open(

if orm_metadata:
client = cls(id=orm_metadata.id, storage_client=storage_client)
await client._update_metadata(session, update_accessed_at=True)
await client._add_buffer_record(session)
# Ensure any pending buffer updates are processed
await client._process_buffers()
else:
now = datetime.now(timezone.utc)
metadata = metadata_model(
@@ -121,8 +129,6 @@
**extra_metadata_fields,
)
client = cls(id=metadata.id, storage_client=storage_client)
client._accessed_at_allow_update_after = now + client._accessed_modified_update_interval
client._modified_at_allow_update_after = now + client._accessed_modified_update_interval
session.add(cls._METADATA_TABLE(**metadata.model_dump(), internal_name=internal_name))

return client
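
# Note on the reopen path above: the access is now recorded as a buffer row
# (replacing the old direct `_update_metadata(session, update_accessed_at=True)`
# call), and `_process_buffers()` then folds any pending rows into the metadata
# in a single aggregated UPDATE.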
Expand Down Expand Up @@ -262,6 +268,9 @@ async def _purge(self, metadata_kwargs: MetadataUpdateParams) -> None:
Args:
metadata_kwargs: Arguments to pass to _update_metadata.
"""
# Process buffers to ensure metadata is up to date before purging
await self._process_buffers()

stmt = delete(self._ITEM_TABLE).where(self._ITEM_TABLE.storage_id == self._id)
async with self.get_session(with_simple_commit=True) as session:
await session.execute(stmt)
Expand Down Expand Up @@ -290,53 +299,16 @@ async def _get_metadata(
self, metadata_model: type[DatasetMetadata | KeyValueStoreMetadata | RequestQueueMetadata]
) -> DatasetMetadata | KeyValueStoreMetadata | RequestQueueMetadata:
"""Retrieve client metadata."""
# Process any pending buffer updates first
await self._process_buffers()

async with self.get_session() as session:
orm_metadata = await session.get(self._METADATA_TABLE, self._id)
if not orm_metadata:
raise ValueError(f'{self._CLIENT_TYPE} with ID "{self._id}" not found.')

return metadata_model.model_validate(orm_metadata)

def _default_update_metadata(
self, *, update_accessed_at: bool = False, update_modified_at: bool = False, force: bool = False
) -> dict[str, Any]:
"""Prepare common metadata updates with rate limiting.

Args:
update_accessed_at: Whether to update accessed_at timestamp.
update_modified_at: Whether to update modified_at timestamp.
force: Whether to force the update regardless of rate limiting.
"""
values_to_set: dict[str, Any] = {}
now = datetime.now(timezone.utc)

# If the record must be updated (for example, when updating counters), we update timestamps and shift the time.
if force:
if update_modified_at:
values_to_set['modified_at'] = now
self._modified_at_allow_update_after = now + self._accessed_modified_update_interval
if update_accessed_at:
values_to_set['accessed_at'] = now
self._accessed_at_allow_update_after = now + self._accessed_modified_update_interval

elif update_modified_at and (
self._modified_at_allow_update_after is None or now >= self._modified_at_allow_update_after
):
values_to_set['modified_at'] = now
self._modified_at_allow_update_after = now + self._accessed_modified_update_interval
# The record will be updated, we can update `accessed_at` and shift the time.
if update_accessed_at:
values_to_set['accessed_at'] = now
self._accessed_at_allow_update_after = now + self._accessed_modified_update_interval

elif update_accessed_at and (
self._accessed_at_allow_update_after is None or now >= self._accessed_at_allow_update_after
):
values_to_set['accessed_at'] = now
self._accessed_at_allow_update_after = now + self._accessed_modified_update_interval

return values_to_set

@abstractmethod
def _specific_update_metadata(self, **kwargs: Any) -> dict[str, Any]:
"""Prepare storage-specific metadata updates.
@@ -347,39 +319,188 @@ def _specific_update_metadata(self, **kwargs: Any) -> dict[str, Any]:
**kwargs: Storage-specific update parameters.
"""

@abstractmethod
def _prepare_buffer_data(self, **kwargs: Any) -> dict[str, Any]:
"""Prepare storage-specific buffer data. Must be implemented by concrete classes."""

@abstractmethod
async def _apply_buffer_updates(self, session: AsyncSession, max_buffer_id: int) -> None:
"""Apply aggregated buffer updates to metadata. Must be implemented by concrete classes.

Args:
session: Active database session.
max_buffer_id: Maximum buffer record ID to process.
"""

async def _update_metadata(
self,
session: AsyncSession,
*,
update_accessed_at: bool = False,
update_modified_at: bool = False,
force: bool = False,
**kwargs: Any,
) -> bool:
"""Update storage metadata combining common and specific fields.
) -> None:
"""Directly update storage metadata combining common and specific fields.

Args:
session: Active database session.
update_accessed_at: Whether to update accessed_at timestamp.
update_modified_at: Whether to update modified_at timestamp.
force: Whether to force the update timestamps regardless of rate limiting.
**kwargs: Additional arguments for _specific_update_metadata.

Returns:
True if any updates were made, False otherwise
"""
values_to_set = self._default_update_metadata(
update_accessed_at=update_accessed_at, update_modified_at=update_modified_at, force=force
)
values_to_set: dict[str, Any] = {}
now = datetime.now(timezone.utc)

if update_accessed_at:
values_to_set['accessed_at'] = now

if update_modified_at:
values_to_set['modified_at'] = now

values_to_set.update(self._specific_update_metadata(**kwargs))

if values_to_set:
if (stmt := values_to_set.pop('custom_stmt', None)) is None:
stmt = update(self._METADATA_TABLE).where(self._METADATA_TABLE.id == self._id)
stmt = update(self._METADATA_TABLE).where(self._METADATA_TABLE.id == self._id)

stmt = stmt.values(**values_to_set)
await session.execute(stmt)

async def _add_buffer_record(
self,
session: AsyncSession,
*,
update_modified_at: bool = False,
**kwargs: Any,
) -> None:
"""Add a record to the buffer table and update metadata.

Args:
session: Active database session.
update_modified_at: Whether to update modified_at timestamp.
**kwargs: Additional arguments for _prepare_buffer_data.
"""
now = datetime.now(timezone.utc)
values_to_set = {
'storage_id': self._id,
'accessed_at': now, # All entries in the buffer require updating `accessed_at`
'modified_at': now if update_modified_at else None,
}
values_to_set.update(self._prepare_buffer_data(**kwargs))

session.add(self._BUFFER_TABLE(**values_to_set))
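
# Typical call site (illustrative; the surrounding write logic is assumed):
#
#     async with self.get_session(with_simple_commit=True) as session:
#         ...  # insert or update the actual items in self._ITEM_TABLE
#         await self._add_buffer_record(session, update_modified_at=True)
#
# Writers thus pay one cheap INSERT per operation; the metadata row is only
# touched when `_process_buffers` aggregates the buffer rows.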

async def _try_acquire_buffer_lock(self, session: AsyncSession) -> bool:
"""Try to acquire buffer processing lock for 200ms.

Args:
session: Active database session.

Returns:
True if lock was acquired, False if already locked by another process.
"""
now = datetime.now(timezone.utc)
lock_until = now + self._BLOCK_BUFFER_TIME
dialect = self._storage_client.get_dialect_name()

if dialect == 'postgresql':
select_stmt = (
select(self._METADATA_TABLE)
.where(
self._METADATA_TABLE.id == self._id,
(self._METADATA_TABLE.buffer_locked_until.is_(None))
| (self._METADATA_TABLE.buffer_locked_until < now),
select(self._BUFFER_TABLE.id).where(self._BUFFER_TABLE.storage_id == self._id).exists(),
)
.with_for_update(skip_locked=True)
)
result = await session.execute(select_stmt)
metadata_row = result.scalar_one_or_none()

if metadata_row is None:
# Either conditions not met OR row is locked by another process
return False

# Acquire lock only if not currently locked or lock has expired
update_stmt = (
update(self._METADATA_TABLE)
.where(
self._METADATA_TABLE.id == self._id,
(self._METADATA_TABLE.buffer_locked_until.is_(None)) | (self._METADATA_TABLE.buffer_locked_until < now),
select(self._BUFFER_TABLE.id).where(self._BUFFER_TABLE.storage_id == self._id).exists(),
)
.values(buffer_locked_until=lock_until)
)

result = await session.execute(update_stmt)
# Narrow the type for the checker: an UPDATE yields a `CursorResult` exposing `rowcount`.
result = cast('CursorResult', result) if not isinstance(result, CursorResult) else result

if result.rowcount > 0:
await session.flush()
return True

return False
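
# Roughly the SQL the conditional UPDATE above emits (illustrative, table
# names elided):
#
#     UPDATE <metadata> SET buffer_locked_until = :lock_until
#     WHERE id = :id
#       AND (buffer_locked_until IS NULL OR buffer_locked_until < :now)
#       AND EXISTS (SELECT 1 FROM <buffer> WHERE storage_id = :id)
#
# On PostgreSQL the preceding SELECT ... FOR UPDATE SKIP LOCKED lets a second
# worker bail out immediately instead of blocking on the row lock; on SQLite
# the conditional UPDATE alone serves as the guard.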

async def _release_buffer_lock(self, session: AsyncSession) -> None:
"""Release buffer processing lock by setting buffer_locked_until to NULL.

Args:
session: Active database session.
"""
stmt = update(self._METADATA_TABLE).where(self._METADATA_TABLE.id == self._id).values(buffer_locked_until=None)

await session.execute(stmt)

await session.flush()

async def _has_pending_buffer_updates(self, session: AsyncSession) -> bool:
"""Check if there are pending buffer updates not yet applied to metadata.

Returns False only when buffer_locked_until is NULL (metadata is consistent).

Returns:
True if metadata might be inconsistent due to pending buffer updates.
"""
result = await session.execute(
select(self._METADATA_TABLE.buffer_locked_until).where(self._METADATA_TABLE.id == self._id)
)

locked_until = result.scalar()

# Any non-NULL value means there are pending updates
return locked_until is not None

async def _process_buffers(self) -> None:
"""Process pending buffer updates and apply them to metadata."""
async with self.get_session(with_simple_commit=True) as session:
# Try to acquire buffer processing lock
if not await self._try_acquire_buffer_lock(session):
# Another process is currently processing buffers or lock acquisition failed
return

# Get the maximum buffer ID at this moment
# This creates a consistent snapshot - records added during processing won't be included
max_buffer_id_stmt = select(sql_func.max(self._BUFFER_TABLE.id)).where(
self._BUFFER_TABLE.storage_id == self._id
)

result = await session.execute(max_buffer_id_stmt)
max_buffer_id = result.scalar()

if max_buffer_id is None:
# No buffer records to process. Release the lock and exit.
await self._release_buffer_lock(session)
return

# Apply aggregated buffer updates to metadata using only records <= max_buffer_id
# This method is implemented by concrete storage classes
await self._apply_buffer_updates(session, max_buffer_id=max_buffer_id)

# Clean up only the processed buffer records (those <= max_buffer_id)
delete_stmt = delete(self._BUFFER_TABLE).where(
self._BUFFER_TABLE.storage_id == self._id, self._BUFFER_TABLE.id <= max_buffer_id
)

await session.execute(delete_stmt)

# Release the lock after successful processing
await self._release_buffer_lock(session)
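
# End-to-end flow of the buffer mechanism (illustrative; the metadata model
# passed to `_get_metadata` depends on the concrete client):
#
#     await client._add_buffer_record(session)        # many cheap INSERTs on the hot path
#     metadata = await client._get_metadata(DatasetMetadata)  # flushes via _process_buffers()
#
# Buffer rows added after the max-id snapshot is taken stay in the table and
# are picked up by the next flush, so updates are deferred, never lost.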