diff --git a/docs/guides/storage_clients.mdx b/docs/guides/storage_clients.mdx index 70c4964192..9e777782f4 100644 --- a/docs/guides/storage_clients.mdx +++ b/docs/guides/storage_clients.mdx @@ -208,6 +208,15 @@ class dataset_records { + data } +class dataset_metadata_buffer { + <> + + id (PK) + + accessed_at + + modified_at + + dataset_id (FK) + + delta_item_count +} + %% ======================== %% Key-Value Store Tables %% ======================== @@ -231,15 +240,25 @@ class key_value_store_records { + size } +class key_value_store_metadata_buffer { + <
> + + id (PK) + + accessed_at + + modified_at + + key_value_store_id (FK) +} + %% ======================== %% Client to Table arrows %% ======================== SqlDatasetClient --> datasets SqlDatasetClient --> dataset_records +SqlDatasetClient --> dataset_metadata_buffer SqlKeyValueStoreClient --> key_value_stores SqlKeyValueStoreClient --> key_value_store_records +SqlKeyValueStoreClient --> key_value_store_metadata_buffer ``` ```mermaid --- @@ -294,6 +313,19 @@ class request_queue_state { + forefront_sequence_counter } +class request_queue_metadata_buffer { + <
> + + id (PK) + + accessed_at + + modified_at + + request_queue_id (FK) + + client_id + + delta_handled_count + + delta_pending_count + + delta_total_count + + need_recalc +} + %% ======================== %% Client to Table arrows %% ======================== @@ -301,6 +333,7 @@ class request_queue_state { SqlRequestQueueClient --> request_queues SqlRequestQueueClient --> request_queue_records SqlRequestQueueClient --> request_queue_state +SqlRequestQueueClient --> request_queue_metadata_buffer ``` Configuration options for the `SqlStorageClient` can be set through environment variables or the `Configuration` class: diff --git a/src/crawlee/storage_clients/_sql/_client_mixin.py b/src/crawlee/storage_clients/_sql/_client_mixin.py index c681e3a220..234a0abdad 100644 --- a/src/crawlee/storage_clients/_sql/_client_mixin.py +++ b/src/crawlee/storage_clients/_sql/_client_mixin.py @@ -2,11 +2,12 @@ from abc import ABC, abstractmethod from contextlib import asynccontextmanager -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from logging import getLogger from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast, overload -from sqlalchemy import delete, select, text, update +from sqlalchemy import CursorResult, delete, select, text, update +from sqlalchemy import func as sql_func from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.dialects.sqlite import insert as lite_insert from sqlalchemy.exc import SQLAlchemyError @@ -25,10 +26,13 @@ from ._db_models import ( DatasetItemDb, + DatasetMetadataBufferDb, DatasetMetadataDb, + KeyValueStoreMetadataBufferDb, KeyValueStoreMetadataDb, KeyValueStoreRecordDb, RequestDb, + RequestQueueMetadataBufferDb, RequestQueueMetadataDb, ) from ._storage_client import SqlStorageClient @@ -40,9 +44,8 @@ class MetadataUpdateParams(TypedDict, total=False): """Parameters for updating metadata.""" - update_accessed_at: NotRequired[bool] - update_modified_at: NotRequired[bool] - force: NotRequired[bool] + accessed_at: NotRequired[datetime] + modified_at: NotRequired[datetime] class SqlClientMixin(ABC): @@ -57,21 +60,24 @@ class SqlClientMixin(ABC): _METADATA_TABLE: ClassVar[type[DatasetMetadataDb | KeyValueStoreMetadataDb | RequestQueueMetadataDb]] """SQLAlchemy model for metadata.""" + _BUFFER_TABLE: ClassVar[ + type[KeyValueStoreMetadataBufferDb | DatasetMetadataBufferDb | RequestQueueMetadataBufferDb] + ] + """SQLAlchemy model for metadata buffer.""" + _ITEM_TABLE: ClassVar[type[DatasetItemDb | KeyValueStoreRecordDb | RequestDb]] """SQLAlchemy model for items.""" _CLIENT_TYPE: ClassVar[str] """Human-readable client type for error messages.""" + _BLOCK_BUFFER_TIME = timedelta(seconds=1) + """Time interval that blocks buffer reading to update metadata.""" + def __init__(self, *, id: str, storage_client: SqlStorageClient) -> None: self._id = id self._storage_client = storage_client - # Time tracking to reduce database writes during frequent operation - self._accessed_at_allow_update_after: datetime | None = None - self._modified_at_allow_update_after: datetime | None = None - self._accessed_modified_update_interval = storage_client.get_accessed_modified_update_interval() - @classmethod async def _open( cls, @@ -109,7 +115,9 @@ async def _open( if orm_metadata: client = cls(id=orm_metadata.id, storage_client=storage_client) - await client._update_metadata(session, update_accessed_at=True) + await client._add_buffer_record(session) + # Ensure any pending buffer updates are processed + await 
client._process_buffers() else: now = datetime.now(timezone.utc) metadata = metadata_model( @@ -121,8 +129,6 @@ async def _open( **extra_metadata_fields, ) client = cls(id=metadata.id, storage_client=storage_client) - client._accessed_at_allow_update_after = now + client._accessed_modified_update_interval - client._modified_at_allow_update_after = now + client._accessed_modified_update_interval session.add(cls._METADATA_TABLE(**metadata.model_dump(), internal_name=internal_name)) return client @@ -262,9 +268,12 @@ async def _purge(self, metadata_kwargs: MetadataUpdateParams) -> None: Args: metadata_kwargs: Arguments to pass to _update_metadata. """ - stmt = delete(self._ITEM_TABLE).where(self._ITEM_TABLE.storage_id == self._id) + # Process buffers to ensure metadata is up to date before purging + await self._process_buffers() + + stmt_records = delete(self._ITEM_TABLE).where(self._ITEM_TABLE.storage_id == self._id) async with self.get_session(with_simple_commit=True) as session: - await session.execute(stmt) + await session.execute(stmt_records) await self._update_metadata(session, **metadata_kwargs) async def _drop(self) -> None: @@ -290,6 +299,9 @@ async def _get_metadata( self, metadata_model: type[DatasetMetadata | KeyValueStoreMetadata | RequestQueueMetadata] ) -> DatasetMetadata | KeyValueStoreMetadata | RequestQueueMetadata: """Retrieve client metadata.""" + # Process any pending buffer updates first + await self._process_buffers() + async with self.get_session() as session: orm_metadata = await session.get(self._METADATA_TABLE, self._id) if not orm_metadata: @@ -297,46 +309,6 @@ async def _get_metadata( return metadata_model.model_validate(orm_metadata) - def _default_update_metadata( - self, *, update_accessed_at: bool = False, update_modified_at: bool = False, force: bool = False - ) -> dict[str, Any]: - """Prepare common metadata updates with rate limiting. - - Args: - update_accessed_at: Whether to update accessed_at timestamp. - update_modified_at: Whether to update modified_at timestamp. - force: Whether to force the update regardless of rate limiting. - """ - values_to_set: dict[str, Any] = {} - now = datetime.now(timezone.utc) - - # If the record must be updated (for example, when updating counters), we update timestamps and shift the time. - if force: - if update_modified_at: - values_to_set['modified_at'] = now - self._modified_at_allow_update_after = now + self._accessed_modified_update_interval - if update_accessed_at: - values_to_set['accessed_at'] = now - self._accessed_at_allow_update_after = now + self._accessed_modified_update_interval - - elif update_modified_at and ( - self._modified_at_allow_update_after is None or now >= self._modified_at_allow_update_after - ): - values_to_set['modified_at'] = now - self._modified_at_allow_update_after = now + self._accessed_modified_update_interval - # The record will be updated, we can update `accessed_at` and shift the time. - if update_accessed_at: - values_to_set['accessed_at'] = now - self._accessed_at_allow_update_after = now + self._accessed_modified_update_interval - - elif update_accessed_at and ( - self._accessed_at_allow_update_after is None or now >= self._accessed_at_allow_update_after - ): - values_to_set['accessed_at'] = now - self._accessed_at_allow_update_after = now + self._accessed_modified_update_interval - - return values_to_set - @abstractmethod def _specific_update_metadata(self, **kwargs: Any) -> dict[str, Any]: """Prepare storage-specific metadata updates. 
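The removal of `_default_update_metadata` above drops the old per-client rate limiting (`_accessed_at_allow_update_after` / `_modified_at_allow_update_after`) in favour of buffered metadata updates: each operation appends a small row to a `*_metadata_buffer` table, and `_process_buffers` later collapses those rows into a single metadata write. A minimal in-memory sketch of that aggregation semantics follows — an illustration only, not the SQL implementation, with a deliberately simplified row shape:

```python
from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime, timedelta, timezone


@dataclass
class BufferRow:
    """Simplified stand-in for one row of a *_metadata_buffer table."""

    accessed_at: datetime
    modified_at: datetime | None = None
    delta_item_count: int = 0


def aggregate(rows: list[BufferRow]) -> dict:
    """Collapse many small buffered updates into one metadata write."""
    modified = [r.modified_at for r in rows if r.modified_at is not None]
    return {
        'accessed_at': max(r.accessed_at for r in rows),  # newest read wins
        'modified_at': max(modified) if modified else None,  # newest write wins
        'delta_item_count': sum(r.delta_item_count for r in rows),  # deltas are summed
    }


now = datetime.now(timezone.utc)
rows = [
    BufferRow(accessed_at=now, modified_at=now, delta_item_count=10),
    BufferRow(accessed_at=now + timedelta(seconds=2)),  # a read: only accessed_at
    BufferRow(accessed_at=now + timedelta(seconds=5), modified_at=now + timedelta(seconds=5), delta_item_count=3),
]
print(aggregate(rows))  # one metadata UPDATE instead of three
```

The SQL version of the same idea appears in the `_apply_buffer_updates` implementations further down, expressed with `max()` and `sum()` aggregates.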
@@ -347,30 +319,42 @@ def _specific_update_metadata(self, **kwargs: Any) -> dict[str, Any]: **kwargs: Storage-specific update parameters. """ + @abstractmethod + def _prepare_buffer_data(self, **kwargs: Any) -> dict[str, Any]: + """Prepare storage-specific buffer data. Must be implemented by concrete classes.""" + + @abstractmethod + async def _apply_buffer_updates(self, session: AsyncSession, max_buffer_id: int) -> None: + """Apply aggregated buffer updates to metadata. Must be implemented by concrete classes. + + Args: + session: Active database session. + max_buffer_id: Maximum buffer record ID to process. + """ + async def _update_metadata( self, session: AsyncSession, *, - update_accessed_at: bool = False, - update_modified_at: bool = False, - force: bool = False, + accessed_at: datetime | None = None, + modified_at: datetime | None = None, **kwargs: Any, - ) -> bool: - """Update storage metadata combining common and specific fields. + ) -> None: + """Directly update storage metadata combining common and specific fields. Args: session: Active database session. - update_accessed_at: Whether to update accessed_at timestamp. - update_modified_at: Whether to update modified_at timestamp. - force: Whether to force the update timestamps regardless of rate limiting. + accessed_at: Datetime to set as accessed_at timestamp. + modified_at: Datetime to set as modified_at timestamp. **kwargs: Additional arguments for _specific_update_metadata. - - Returns: - True if any updates were made, False otherwise """ - values_to_set = self._default_update_metadata( - update_accessed_at=update_accessed_at, update_modified_at=update_modified_at, force=force - ) + values_to_set: dict[str, Any] = {} + + if accessed_at is not None: + values_to_set['accessed_at'] = accessed_at + + if modified_at is not None: + values_to_set['modified_at'] = modified_at values_to_set.update(self._specific_update_metadata(**kwargs)) @@ -380,6 +364,141 @@ async def _update_metadata( stmt = stmt.values(**values_to_set) await session.execute(stmt) + + async def _add_buffer_record( + self, + session: AsyncSession, + *, + update_modified_at: bool = False, + **kwargs: Any, + ) -> None: + """Add a record to the buffer table and update metadata. + + Args: + session: Active database session. + update_modified_at: Whether to update modified_at timestamp. + **kwargs: Additional arguments for _prepare_buffer_data. + """ + now = datetime.now(timezone.utc) + values_to_set = { + 'storage_id': self._id, + 'accessed_at': now, # All entries in the buffer require updating `accessed_at` + 'modified_at': now if update_modified_at else None, + } + values_to_set.update(self._prepare_buffer_data(**kwargs)) + + session.add(self._BUFFER_TABLE(**values_to_set)) + + async def _try_acquire_buffer_lock(self, session: AsyncSession) -> bool: + """Try to acquire buffer processing lock for 200ms. + + Args: + session: Active database session. + + Returns: + True if lock was acquired, False if already locked by another process. 
+ """ + now = datetime.now(timezone.utc) + lock_until = now + self._BLOCK_BUFFER_TIME + dialect = self._storage_client.get_dialect_name() + + if dialect == 'postgresql': + select_stmt = ( + select(self._METADATA_TABLE) + .where( + self._METADATA_TABLE.id == self._id, + (self._METADATA_TABLE.buffer_locked_until.is_(None)) + | (self._METADATA_TABLE.buffer_locked_until < now), + select(self._BUFFER_TABLE.id).where(self._BUFFER_TABLE.storage_id == self._id).exists(), + ) + .with_for_update(skip_locked=True) + ) + result = await session.execute(select_stmt) + metadata_row = result.scalar_one_or_none() + + if metadata_row is None: + # Either conditions not met OR row is locked by another process + return False + + # Acquire lock only if not currently locked or lock has expired + update_stmt = ( + update(self._METADATA_TABLE) + .where( + self._METADATA_TABLE.id == self._id, + (self._METADATA_TABLE.buffer_locked_until.is_(None)) | (self._METADATA_TABLE.buffer_locked_until < now), + select(self._BUFFER_TABLE.id).where(self._BUFFER_TABLE.storage_id == self._id).exists(), + ) + .values(buffer_locked_until=lock_until) + ) + + result = await session.execute(update_stmt) + result = cast('CursorResult', result) if not isinstance(result, CursorResult) else result + + if result.rowcount > 0: + await session.flush() return True return False + + async def _release_buffer_lock(self, session: AsyncSession) -> None: + """Release buffer processing lock by setting buffer_locked_until to NULL. + + Args: + session: Active database session. + """ + stmt = update(self._METADATA_TABLE).where(self._METADATA_TABLE.id == self._id).values(buffer_locked_until=None) + + await session.execute(stmt) + + async def _has_pending_buffer_updates(self, session: AsyncSession) -> bool: + """Check if there are pending buffer updates not yet applied to metadata. + + Returns False only when buffer_locked_until is NULL (metadata is consistent). + + Returns: + True if metadata might be inconsistent due to pending buffer updates. + """ + result = await session.execute( + select(self._METADATA_TABLE.buffer_locked_until).where(self._METADATA_TABLE.id == self._id) + ) + + locked_until = result.scalar() + + # Any non-NULL value means there are pending updates + return locked_until is not None + + async def _process_buffers(self) -> None: + """Process pending buffer updates and apply them to metadata.""" + async with self.get_session(with_simple_commit=True) as session: + # Try to acquire buffer processing lock + if not await self._try_acquire_buffer_lock(session): + # Another process is currently processing buffers or lock acquisition failed + return + + # Get the maximum buffer ID at this moment + # This creates a consistent snapshot - records added during processing won't be included + max_buffer_id_stmt = select(sql_func.max(self._BUFFER_TABLE.id)).where( + self._BUFFER_TABLE.storage_id == self._id + ) + + result = await session.execute(max_buffer_id_stmt) + max_buffer_id = result.scalar() + + if max_buffer_id is None: + # No buffer records to process. Release the lock and exit. 
+ await self._release_buffer_lock(session) + return + + # Apply aggregated buffer updates to metadata using only records <= max_buffer_id + # This method is implemented by concrete storage classes + await self._apply_buffer_updates(session, max_buffer_id=max_buffer_id) + + # Clean up only the processed buffer records (those <= max_buffer_id) + delete_stmt = delete(self._BUFFER_TABLE).where( + self._BUFFER_TABLE.storage_id == self._id, self._BUFFER_TABLE.id <= max_buffer_id + ) + + await session.execute(delete_stmt) + + # Release the lock after successful processing + await self._release_buffer_lock(session) diff --git a/src/crawlee/storage_clients/_sql/_dataset_client.py b/src/crawlee/storage_clients/_sql/_dataset_client.py index 61873975c8..1ea9d6b7cb 100644 --- a/src/crawlee/storage_clients/_sql/_dataset_client.py +++ b/src/crawlee/storage_clients/_sql/_dataset_client.py @@ -1,21 +1,24 @@ from __future__ import annotations +from datetime import datetime, timezone from logging import getLogger from typing import TYPE_CHECKING, Any from sqlalchemy import Select, insert, select +from sqlalchemy import func as sql_func from typing_extensions import Self, override from crawlee.storage_clients._base import DatasetClient from crawlee.storage_clients.models import DatasetItemsListPage, DatasetMetadata from ._client_mixin import MetadataUpdateParams, SqlClientMixin -from ._db_models import DatasetItemDb, DatasetMetadataDb +from ._db_models import DatasetItemDb, DatasetMetadataBufferDb, DatasetMetadataDb if TYPE_CHECKING: from collections.abc import AsyncIterator from sqlalchemy import Select + from sqlalchemy.ext.asyncio import AsyncSession from typing_extensions import NotRequired from ._storage_client import SqlStorageClient @@ -40,6 +43,7 @@ class SqlDatasetClient(DatasetClient, SqlClientMixin): The dataset data is stored in SQL database tables following the pattern: - `datasets` table: Contains dataset metadata (id, name, timestamps, item_count) - `dataset_records` table: Contains individual items with JSON data and auto-increment ordering + - `dataset_metadata_buffer` table: Buffers metadata updates for performance optimization Items are stored as a JSON object in SQLite and as JSONB in PostgreSQL. These objects must be JSON-serializable. The `item_id` auto-increment primary key ensures insertion order is preserved. @@ -58,6 +62,9 @@ class SqlDatasetClient(DatasetClient, SqlClientMixin): _CLIENT_TYPE = 'Dataset' """Human-readable client type for error messages.""" + _BUFFER_TABLE = DatasetMetadataBufferDb + """SQLAlchemy model for metadata buffer.""" + def __init__( self, *, @@ -121,12 +128,12 @@ async def purge(self) -> None: Resets item_count to 0 and deletes all records from dataset_records table. 
""" + now = datetime.now(timezone.utc) await self._purge( metadata_kwargs=_DatasetMetadataUpdateParams( new_item_count=0, - update_accessed_at=True, - update_modified_at=True, - force=True, + accessed_at=now, + modified_at=now, ) ) @@ -135,23 +142,13 @@ async def push_data(self, data: list[dict[str, Any]] | dict[str, Any]) -> None: if not isinstance(data, list): data = [data] - db_items: list[dict[str, Any]] = [] db_items = [{'dataset_id': self._id, 'data': item} for item in data] stmt = insert(self._ITEM_TABLE).values(db_items) async with self.get_session(with_simple_commit=True) as session: await session.execute(stmt) - await self._update_metadata( - session, - **_DatasetMetadataUpdateParams( - update_accessed_at=True, - update_modified_at=True, - delta_item_count=len(data), - new_item_count=len(data), - force=True, - ), - ) + await self._add_buffer_record(session, update_modified_at=True, delta_item_count=len(data)) @override async def get_data( @@ -183,15 +180,11 @@ async def get_data( view=view, ) - async with self.get_session() as session: + async with self.get_session(with_simple_commit=True) as session: result = await session.execute(stmt) db_items = result.scalars().all() - updated = await self._update_metadata(session, **_DatasetMetadataUpdateParams(update_accessed_at=True)) - - # Commit updates to the metadata - if updated: - await session.commit() + await self._add_buffer_record(session) items = [db_item.data for db_item in db_items] metadata = await self.get_metadata() @@ -230,17 +223,13 @@ async def iterate_items( skip_hidden=skip_hidden, ) - async with self.get_session() as session: + async with self.get_session(with_simple_commit=True) as session: db_items = await session.stream_scalars(stmt) async for db_item in db_items: yield db_item.data - updated = await self._update_metadata(session, **_DatasetMetadataUpdateParams(update_accessed_at=True)) - - # Commit updates to the metadata - if updated: - await session.commit() + await self._add_buffer_record(session) def _prepare_get_stmt( self, @@ -286,13 +275,14 @@ def _prepare_get_stmt( return stmt.offset(offset).limit(limit) + @override def _specific_update_metadata( self, new_item_count: int | None = None, delta_item_count: int | None = None, **_kwargs: dict[str, Any], ) -> dict[str, Any]: - """Update the dataset metadata in the database. + """Directly update the dataset metadata in the database. Args: session: The SQLAlchemy AsyncSession to use for the update. @@ -308,3 +298,39 @@ def _specific_update_metadata( values_to_set['item_count'] = self._METADATA_TABLE.item_count + delta_item_count return values_to_set + + @override + def _prepare_buffer_data(self, delta_item_count: int | None = None, **_kwargs: Any) -> dict[str, Any]: + """Prepare dataset specific buffer data. + + Args: + delta_item_count: If provided, add this value to the current item count. 
+ """ + buffer_data = {} + if delta_item_count is not None: + buffer_data['delta_item_count'] = delta_item_count + + return buffer_data + + @override + async def _apply_buffer_updates(self, session: AsyncSession, max_buffer_id: int) -> None: + aggregation_stmt = select( + sql_func.max(self._BUFFER_TABLE.accessed_at).label('max_accessed_at'), + sql_func.max(self._BUFFER_TABLE.modified_at).label('max_modified_at'), + sql_func.sum(self._BUFFER_TABLE.delta_item_count).label('delta_item_count'), + ).where(self._BUFFER_TABLE.storage_id == self._id, self._BUFFER_TABLE.id <= max_buffer_id) + + result = await session.execute(aggregation_stmt) + row = result.first() + + if not row: + return + + await self._update_metadata( + session, + **_DatasetMetadataUpdateParams( + accessed_at=row.max_accessed_at, + modified_at=row.max_modified_at, + delta_item_count=row.delta_item_count, + ), + ) diff --git a/src/crawlee/storage_clients/_sql/_db_models.py b/src/crawlee/storage_clients/_sql/_db_models.py index 2a8f8b565b..00f38f200b 100644 --- a/src/crawlee/storage_clients/_sql/_db_models.py +++ b/src/crawlee/storage_clients/_sql/_db_models.py @@ -52,10 +52,10 @@ class Base(DeclarativeBase): class StorageMetadataDb: """Base database model for storage metadata.""" - internal_name: Mapped[str] = mapped_column(String, nullable=False, index=True, unique=True) + internal_name: Mapped[str] = mapped_column(String(255), nullable=False, index=True, unique=True) """Internal unique name for a storage instance based on a name or alias.""" - name: Mapped[str | None] = mapped_column(String, nullable=True, unique=True) + name: Mapped[str | None] = mapped_column(String(255), nullable=True, unique=True) """Human-readable name. None becomes 'default' in database to enforce uniqueness.""" accessed_at: Mapped[datetime] = mapped_column(AwareDateTime, nullable=False) @@ -67,6 +67,9 @@ class StorageMetadataDb: modified_at: Mapped[datetime] = mapped_column(AwareDateTime, nullable=False) """Last modification datetime.""" + buffer_locked_until: Mapped[datetime | None] = mapped_column(AwareDateTime, nullable=True) + """Timestamp until which buffer processing is locked for this storage. 
NULL = unlocked.""" + class DatasetMetadataDb(StorageMetadataDb, Base): """Metadata table for datasets.""" @@ -84,6 +87,11 @@ class DatasetMetadataDb(StorageMetadataDb, Base): back_populates='dataset', cascade='all, delete-orphan', lazy='noload' ) + # Relationship to metadata buffer updates + buffer: Mapped[list[DatasetMetadataBufferDb]] = relationship( + back_populates='dataset', cascade='all, delete-orphan', lazy='noload' + ) + id = synonym('dataset_id') """Alias for dataset_id to match Pydantic expectations.""" @@ -92,6 +100,13 @@ class RequestQueueMetadataDb(StorageMetadataDb, Base): """Metadata table for request queues.""" __tablename__ = 'request_queues' + __table_args__ = ( + Index( + 'idx_buffer_lock', + 'request_queue_id', + 'buffer_locked_until', + ), + ) request_queue_id: Mapped[str] = mapped_column(String(20), nullable=False, primary_key=True) """Unique identifier for the request queue.""" @@ -117,6 +132,11 @@ class RequestQueueMetadataDb(StorageMetadataDb, Base): back_populates='queue', cascade='all, delete-orphan', lazy='noload' ) + # Relationship to metadata buffer updates + buffer: Mapped[list[RequestQueueMetadataBufferDb]] = relationship( + back_populates='queue', cascade='all, delete-orphan', lazy='noload' + ) + id = synonym('request_queue_id') """Alias for request_queue_id to match Pydantic expectations.""" @@ -134,6 +154,11 @@ class KeyValueStoreMetadataDb(StorageMetadataDb, Base): back_populates='kvs', cascade='all, delete-orphan', lazy='noload' ) + # Relationship to metadata buffer updates + buffer: Mapped[list[KeyValueStoreMetadataBufferDb]] = relationship( + back_populates='kvs', cascade='all, delete-orphan', lazy='noload' + ) + id = synonym('key_value_store_id') """Alias for key_value_store_id to match Pydantic expectations.""" @@ -206,7 +231,12 @@ class RequestDb(Base): 'request_queue_id', 'is_handled', 'sequence_number', - postgresql_where=text('is_handled is false'), + postgresql_where=text('is_handled = false'), + ), + Index( + 'idx_count_aggregate', + 'request_queue_id', + 'is_handled', ), ) @@ -218,7 +248,7 @@ class RequestDb(Base): ) """Foreign key to metadata request queue record.""" - data: Mapped[str] = mapped_column(String, nullable=False) + data: Mapped[str] = mapped_column(String(5000), nullable=False) """JSON-serialized Request object.""" sequence_number: Mapped[int] = mapped_column(Integer, nullable=False) @@ -266,3 +296,90 @@ class VersionDb(Base): __tablename__ = 'version' version: Mapped[str] = mapped_column(String(10), nullable=False, primary_key=True) + + +class MetadataBufferDb: + """Base model for metadata update buffer tables.""" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + """Auto-increment primary key for ordering.""" + + # Timestamp fields - use max value when aggregating + accessed_at: Mapped[datetime | None] = mapped_column(AwareDateTime, nullable=False) + """New accessed_at timestamp, if being updated.""" + + modified_at: Mapped[datetime | None] = mapped_column(AwareDateTime, nullable=True) + """New modified_at timestamp, if being updated.""" + + +class KeyValueStoreMetadataBufferDb(MetadataBufferDb, Base): + """Buffer table for deferred key-value store metadata updates to reduce concurrent access issues.""" + + __tablename__ = 'key_value_store_metadata_buffer' + + key_value_store_id: Mapped[str] = mapped_column( + String(20), ForeignKey('key_value_stores.key_value_store_id', ondelete='CASCADE'), nullable=False, index=True + ) + """ID of the key-value store being updated.""" + + # Relationship 
back to key-value store metadata + kvs: Mapped[KeyValueStoreMetadataDb] = relationship(back_populates='buffer') + + storage_id = synonym('key_value_store_id') + """Alias for key_value_store_id to match SqlClientMixin expectations.""" + + +class DatasetMetadataBufferDb(MetadataBufferDb, Base): + """Buffer table for deferred dataset metadata updates to reduce concurrent access issues.""" + + __tablename__ = 'dataset_metadata_buffer' + + dataset_id: Mapped[str] = mapped_column( + String(20), ForeignKey('datasets.dataset_id', ondelete='CASCADE'), nullable=False, index=True + ) + """ID of the dataset being updated.""" + + # Counter deltas - use SUM when aggregating + delta_item_count: Mapped[int | None] = mapped_column(Integer, nullable=True) + """Delta for dataset item_count.""" + + # Relationship back to dataset metadata + dataset: Mapped[DatasetMetadataDb] = relationship(back_populates='buffer') + + storage_id = synonym('dataset_id') + """Alias for dataset_id to match SqlClientMixin expectations.""" + + +class RequestQueueMetadataBufferDb(MetadataBufferDb, Base): + """Buffer table for deferred request queue metadata updates to reduce concurrent access issues.""" + + __tablename__ = 'request_queue_metadata_buffer' + + __table_args__ = (Index('idx_rq_client', 'request_queue_id', 'client_id'),) + + request_queue_id: Mapped[str] = mapped_column( + String(20), ForeignKey('request_queues.request_queue_id', ondelete='CASCADE'), nullable=False, index=True + ) + """ID of the request queue being updated.""" + + client_id: Mapped[str] = mapped_column(String(32), nullable=False) + """Identifier of the client making this update.""" + + # Counter deltas - use SUM when aggregating + delta_handled_count: Mapped[int | None] = mapped_column(Integer, nullable=True) + """Delta for handled_request_count.""" + + delta_pending_count: Mapped[int | None] = mapped_column(Integer, nullable=True) + """Delta for pending_request_count.""" + + delta_total_count: Mapped[int | None] = mapped_column(Integer, nullable=True) + """Delta for total_request_count.""" + + need_recalc: Mapped[bool | None] = mapped_column(Boolean, nullable=False, default=False) + """Flag indicating that counters need recalculation from actual data.""" + + # Relationship back to request queue metadata + queue: Mapped[RequestQueueMetadataDb] = relationship(back_populates='buffer') + + storage_id = synonym('request_queue_id') + """Alias for request_queue_id to match SqlClientMixin expectations.""" diff --git a/src/crawlee/storage_clients/_sql/_key_value_store_client.py b/src/crawlee/storage_clients/_sql/_key_value_store_client.py index dfa02d8014..3a06a98cc2 100644 --- a/src/crawlee/storage_clients/_sql/_key_value_store_client.py +++ b/src/crawlee/storage_clients/_sql/_key_value_store_client.py @@ -1,22 +1,30 @@ from __future__ import annotations import json +from datetime import datetime, timezone from logging import getLogger from typing import TYPE_CHECKING, Any, cast from sqlalchemy import CursorResult, delete, select +from sqlalchemy import func as sql_func from typing_extensions import Self, override from crawlee._utils.file import infer_mime_type from crawlee.storage_clients._base import KeyValueStoreClient -from crawlee.storage_clients.models import KeyValueStoreMetadata, KeyValueStoreRecord, KeyValueStoreRecordMetadata +from crawlee.storage_clients.models import ( + KeyValueStoreMetadata, + KeyValueStoreRecord, + KeyValueStoreRecordMetadata, +) from ._client_mixin import MetadataUpdateParams, SqlClientMixin -from ._db_models import 
KeyValueStoreMetadataDb, KeyValueStoreRecordDb +from ._db_models import KeyValueStoreMetadataBufferDb, KeyValueStoreMetadataDb, KeyValueStoreRecordDb if TYPE_CHECKING: from collections.abc import AsyncIterator + from sqlalchemy.ext.asyncio import AsyncSession + from ._storage_client import SqlStorageClient @@ -34,6 +42,7 @@ class SqlKeyValueStoreClient(KeyValueStoreClient, SqlClientMixin): - `key_value_stores` table: Contains store metadata (id, name, timestamps) - `key_value_store_records` table: Contains individual key-value pairs with binary value storage, content type, and size information + - `key_value_store_metadata_buffer` table: Buffers metadata updates for performance optimization Values are serialized based on their type: JSON objects are stored as formatted JSON, text values as UTF-8 encoded strings, and binary data as-is in the `LargeBinary` column. @@ -57,6 +66,9 @@ class SqlKeyValueStoreClient(KeyValueStoreClient, SqlClientMixin): _CLIENT_TYPE = 'Key-value store' """Human-readable client type for error messages.""" + _BUFFER_TABLE = KeyValueStoreMetadataBufferDb + """SQLAlchemy model for metadata buffer.""" + def __init__( self, *, @@ -124,7 +136,8 @@ async def purge(self) -> None: Remove all records from key_value_store_records table. """ - await self._purge(metadata_kwargs=MetadataUpdateParams(update_accessed_at=True, update_modified_at=True)) + now = datetime.now(timezone.utc) + await self._purge(metadata_kwargs=MetadataUpdateParams(accessed_at=now, modified_at=now)) @override async def set_value(self, *, key: str, value: Any, content_type: str | None = None) -> None: @@ -165,9 +178,7 @@ async def set_value(self, *, key: str, value: Any, content_type: str | None = No async with self.get_session(with_simple_commit=True) as session: await session.execute(upsert_stmt) - await self._update_metadata( - session, **MetadataUpdateParams(update_accessed_at=True, update_modified_at=True) - ) + await self._add_buffer_record(session, update_modified_at=True) @override async def get_value(self, *, key: str) -> KeyValueStoreRecord | None: @@ -175,15 +186,11 @@ async def get_value(self, *, key: str) -> KeyValueStoreRecord | None: stmt = select(self._ITEM_TABLE).where( self._ITEM_TABLE.key_value_store_id == self._id, self._ITEM_TABLE.key == key ) - async with self.get_session() as session: + async with self.get_session(with_simple_commit=True) as session: result = await session.execute(stmt) record_db = result.scalar_one_or_none() - updated = await self._update_metadata(session, **MetadataUpdateParams(update_accessed_at=True)) - - # Commit updates to the metadata - if updated: - await session.commit() + await self._add_buffer_record(session) if not record_db: return None @@ -231,11 +238,7 @@ async def delete_value(self, *, key: str) -> None: # Update metadata if we actually deleted something if result.rowcount > 0: - await self._update_metadata( - session, **MetadataUpdateParams(update_accessed_at=True, update_modified_at=True) - ) - - await session.commit() + await self._add_buffer_record(session, update_accessed_at=True, update_modified_at=True) @override async def iterate_keys( @@ -259,7 +262,7 @@ async def iterate_keys( if limit is not None: stmt = stmt.limit(limit) - async with self.get_session() as session: + async with self.get_session(with_simple_commit=True) as session: result = await session.stream(stmt.execution_options(stream_results=True)) async for row in result: @@ -269,26 +272,18 @@ async def iterate_keys( size=row.size, ) - updated = await self._update_metadata(session, 
**MetadataUpdateParams(update_accessed_at=True)) - - # Commit updates to the metadata - if updated: - await session.commit() + await self._add_buffer_record(session) @override async def record_exists(self, *, key: str) -> bool: stmt = select(self._ITEM_TABLE.key).where( self._ITEM_TABLE.key_value_store_id == self._id, self._ITEM_TABLE.key == key ) - async with self.get_session() as session: + async with self.get_session(with_simple_commit=True) as session: # Check if record exists result = await session.execute(stmt) - updated = await self._update_metadata(session, **MetadataUpdateParams(update_accessed_at=True)) - - # Commit updates to the metadata - if updated: - await session.commit() + await self._add_buffer_record(session) return result.scalar_one_or_none() is not None @@ -296,5 +291,36 @@ async def record_exists(self, *, key: str) -> bool: async def get_public_url(self, *, key: str) -> str: raise NotImplementedError('Public URLs are not supported for SQL key-value stores.') + @override def _specific_update_metadata(self, **_kwargs: dict[str, Any]) -> dict[str, Any]: return {} + + @override + def _prepare_buffer_data(self, **_kwargs: Any) -> dict[str, Any]: + """Prepare key-value store specific buffer data. + + For KeyValueStore, we don't have specific metadata fields to track in buffer, + so we just return empty dict. The base buffer will handle accessed_at/modified_at. + """ + return {} + + @override + async def _apply_buffer_updates(self, session: AsyncSession, max_buffer_id: int) -> None: + aggregation_stmt = select( + sql_func.max(self._BUFFER_TABLE.accessed_at).label('max_accessed_at'), + sql_func.max(self._BUFFER_TABLE.modified_at).label('max_modified_at'), + ).where(self._BUFFER_TABLE.storage_id == self._id, self._BUFFER_TABLE.id <= max_buffer_id) + + result = await session.execute(aggregation_stmt) + row = result.first() + + if not row: + return + + await self._update_metadata( + session, + **MetadataUpdateParams( + accessed_at=row.max_accessed_at, + modified_at=row.max_modified_at, + ), + ) diff --git a/src/crawlee/storage_clients/_sql/_request_queue_client.py b/src/crawlee/storage_clients/_sql/_request_queue_client.py index f5a320bb21..cba4f18702 100644 --- a/src/crawlee/storage_clients/_sql/_request_queue_client.py +++ b/src/crawlee/storage_clients/_sql/_request_queue_client.py @@ -7,7 +7,8 @@ from logging import getLogger from typing import TYPE_CHECKING, Any, cast -from sqlalchemy import CursorResult, func, or_, select, update +from sqlalchemy import CursorResult, exists, func, or_, select, update +from sqlalchemy import func as sql_func from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import load_only from typing_extensions import NotRequired, Self, override @@ -23,12 +24,13 @@ ) from ._client_mixin import MetadataUpdateParams, SqlClientMixin -from ._db_models import RequestDb, RequestQueueMetadataDb, RequestQueueStateDb +from ._db_models import RequestDb, RequestQueueMetadataBufferDb, RequestQueueMetadataDb, RequestQueueStateDb if TYPE_CHECKING: from collections.abc import Sequence from sqlalchemy.ext.asyncio import AsyncSession + from sqlalchemy.sql import ColumnElement from ._storage_client import SqlStorageClient @@ -44,6 +46,7 @@ class _QueueMetadataUpdateParams(MetadataUpdateParams): new_total_request_count: NotRequired[int] delta_handled_request_count: NotRequired[int] delta_pending_request_count: NotRequired[int] + delta_total_request_count: NotRequired[int] recalculate: NotRequired[bool] update_had_multiple_clients: NotRequired[bool] @@ -64,6 
+67,7 @@ class SqlRequestQueueClient(RequestQueueClient, SqlClientMixin): - `request_queue_records` table: Contains individual requests with JSON data, unique keys for deduplication, sequence numbers for ordering, and processing status flags - `request_queue_state` table: Maintains counters for sequence numbers to ensure proper ordering of requests. + - `request_queue_metadata_buffer` table: Buffers metadata updates for performance optimization Requests are serialized to JSON for storage and maintain proper ordering through sequence numbers. The implementation provides concurrent access safety through transaction @@ -93,6 +97,9 @@ class SqlRequestQueueClient(RequestQueueClient, SqlClientMixin): """Number of seconds for which a request is considered blocked in the database after being fetched for processing. """ + _BUFFER_TABLE = RequestQueueMetadataBufferDb + """SQLAlchemy model for metadata buffer.""" + def __init__( self, *, @@ -111,6 +118,9 @@ def __init__( self.client_key = crypto_random_object_id(length=32)[:32] """Unique identifier for this client instance.""" + self._had_multiple_clients = False + """Indicates whether the queue has been accessed by multiple clients.""" + @classmethod async def open( cls, @@ -155,7 +165,9 @@ async def open( @override async def get_metadata(self) -> RequestQueueMetadata: # The database is a single place of truth - return await self._get_metadata(RequestQueueMetadata) + metadata = await self._get_metadata(RequestQueueMetadata) + self._had_multiple_clients = metadata.had_multiple_clients + return metadata @override async def drop(self) -> None: @@ -174,12 +186,12 @@ async def purge(self) -> None: Resets pending_request_count and handled_request_count to 0 and deletes all records from request_queue_records table. """ + now = datetime.now(timezone.utc) await self._purge( metadata_kwargs=_QueueMetadataUpdateParams( - update_accessed_at=True, - update_modified_at=True, + accessed_at=now, + modified_at=now, new_pending_request_count=0, - force=True, ) ) @@ -202,7 +214,7 @@ async def add_batch_of_requests( transaction_processed_requests = [] transaction_processed_requests_unique_keys = set() - metadata_recalculate = False + approximate_new_request = 0 # Deduplicate requests by unique_key upfront unique_requests = {} @@ -254,7 +266,6 @@ async def add_batch_of_requests( state.sequence_counter += 1 insert_values.append(value) - metadata_recalculate = True transaction_processed_requests.append( ProcessedRequest( unique_key=request.unique_key, @@ -332,20 +343,20 @@ async def add_batch_of_requests( update_columns=['sequence_number'], conflict_cols=['request_id', 'request_queue_id'], ) - await session.execute(upsert_stmt) + result = await session.execute(upsert_stmt) else: # If the request already exists in the database, we ignore this request when inserting. 
insert_stmt_with_ignore = self._build_insert_stmt_with_ignore(self._ITEM_TABLE, insert_values) - await session.execute(insert_stmt_with_ignore) + result = await session.execute(insert_stmt_with_ignore) + + result = cast('CursorResult', result) if not isinstance(result, CursorResult) else result + approximate_new_request += result.rowcount - await self._update_metadata( + await self._add_buffer_record( session, - **_QueueMetadataUpdateParams( - recalculate=metadata_recalculate, - update_modified_at=True, - update_accessed_at=True, - force=metadata_recalculate, - ), + update_modified_at=True, + delta_pending_request_count=approximate_new_request, + delta_total_request_count=approximate_new_request, ) try: @@ -354,8 +365,10 @@ except SQLAlchemyError as e: await session.rollback() logger.warning(f'Failed to commit session: {e}') - await self._update_metadata( - session, recalculate=True, update_modified_at=True, update_accessed_at=True, force=True + await self._add_buffer_record( + session, + update_modified_at=True, + recalculate=True, ) await session.commit() transaction_processed_requests.clear() @@ -383,7 +396,7 @@ async def get_request(self, unique_key: str) -> Request | None: stmt = select(self._ITEM_TABLE).where( self._ITEM_TABLE.request_queue_id == self._id, self._ITEM_TABLE.request_id == request_id ) - async with self.get_session() as session: + async with self.get_session(with_simple_commit=True) as session: result = await session.execute(stmt) request_db = result.scalar_one_or_none() @@ -391,11 +404,7 @@ async def get_request(self, unique_key: str) -> Request | None: logger.warning(f'Request with ID "{unique_key}" not found in the queue.') return None - updated = await self._update_metadata(session, update_accessed_at=True) - - # Commit updates to the metadata - if updated: - await session.commit() + await self._add_buffer_record(session) return Request.model_validate_json(request_db.data) @@ -413,14 +422,14 @@ async def fetch_next_request(self) -> Request | None: select(self._ITEM_TABLE) .where( self._ITEM_TABLE.request_queue_id == self._id, - self._ITEM_TABLE.is_handled.is_(False), + self._ITEM_TABLE.is_handled == False, # noqa: E712 or_(self._ITEM_TABLE.time_blocked_until.is_(None), self._ITEM_TABLE.time_blocked_until < now), ) .order_by(self._ITEM_TABLE.sequence_number.asc()) .limit(self._MAX_BATCH_FETCH_SIZE) ) - async with self.get_session() as session: + async with self.get_session(with_simple_commit=True) as session: # We use the `skip_locked` database mechanism to prevent the 'interception' of requests by another client if dialect == 'postgresql': stmt = stmt.with_for_update(skip_locked=True) @@ -456,7 +465,7 @@ async def fetch_next_request(self) -> Request | None: .where( self._ITEM_TABLE.request_queue_id == self._id, self._ITEM_TABLE.request_id.in_(request_ids), - self._ITEM_TABLE.is_handled.is_(False), + self._ITEM_TABLE.is_handled == False, # noqa: E712 or_(self._ITEM_TABLE.time_blocked_until.is_(None), self._ITEM_TABLE.time_blocked_until < now), ) .values(time_blocked_until=block_until, client_key=self.client_key) @@ -470,9 +479,7 @@ async def fetch_next_request(self) -> Request | None: await session.rollback() return None - await self._update_metadata(session, **_QueueMetadataUpdateParams(update_accessed_at=True)) - - await session.commit() + await self._add_buffer_record(session) requests = [Request.model_validate_json(r.data) for r in requests_db if r.request_id in blocked_ids] @@ -497,7 +504,7 @@ async def 
mark_request_as_handled(self, request: Request) -> ProcessedRequest | .where(self._ITEM_TABLE.request_queue_id == self._id, self._ITEM_TABLE.request_id == request_id) .values(is_handled=True, time_blocked_until=None, client_key=None, data=request.model_dump_json()) ) - async with self.get_session() as session: + async with self.get_session(with_simple_commit=True) as session: result = await session.execute(stmt) result = cast('CursorResult', result) if not isinstance(result, CursorResult) else result @@ -505,17 +512,9 @@ async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | logger.warning(f'Request {request.unique_key} not found in database.') return None - await self._update_metadata( - session, - **_QueueMetadataUpdateParams( - delta_handled_request_count=1, - delta_pending_request_count=-1, - update_modified_at=True, - update_accessed_at=True, - force=True, - ), + await self._add_buffer_record( + session, update_modified_at=True, delta_pending_request_count=-1, delta_handled_request_count=1 ) - await session.commit() return ProcessedRequest( unique_key=request.unique_key, was_already_present=True, @@ -567,9 +566,7 @@ async def reclaim_request( if result.rowcount == 0: logger.warning(f'Request {request.unique_key} not found in database.') return None - await self._update_metadata( - session, **_QueueMetadataUpdateParams(update_modified_at=True, update_accessed_at=True) - ) + await self._add_buffer_record(session, update_modified_at=True) # put the forefront request at the beginning of the cache if forefront: @@ -587,30 +584,43 @@ async def is_empty(self) -> bool: if self._pending_fetch_cache: return False - # Check database for unhandled requests - async with self.get_session() as session: - metadata_orm = await session.get(self._METADATA_TABLE, self._id) - if not metadata_orm: - raise ValueError(f'Request queue with ID "{self._id}" not found.') + metadata = await self.get_metadata() - empty = metadata_orm.pending_request_count == 0 + async with self.get_session(with_simple_commit=True) as session: + # If there are no pending requests, check if there are any buffered updates + if metadata.pending_request_count == 0: + # Check for active buffer lock (indicates pending buffer processing) + buffer_lock_stmt = select(self._METADATA_TABLE.buffer_locked_until).where( + self._METADATA_TABLE.id == self._id + ) + buffer_lock_result = await session.execute(buffer_lock_stmt) + buffer_locked_until = buffer_lock_result.scalar() + + # If buffer is locked, there are pending updates being processed + if buffer_locked_until is not None: + await self._add_buffer_record(session) + return False + + # Check if there are any buffered updates that might change the pending count + buffer_check_stmt = select( + exists().where( + (self._BUFFER_TABLE.storage_id == self._id) + & ( + (self._BUFFER_TABLE.delta_pending_count != 0) | (self._BUFFER_TABLE.need_recalc == True) # noqa: E712 + ) + ) + ) + buffer_result = await session.execute(buffer_check_stmt) + has_pending_buffer_updates = buffer_result.scalar() - updated = await self._update_metadata( - session, - **_QueueMetadataUpdateParams( - update_accessed_at=True, - # With multi-client access, counters may become out of sync. - # If the queue is not empty, we perform a recalculation to synchronize the counters in the metadata. 
- recalculate=not empty, - update_modified_at=not empty, - ), - ) + await self._add_buffer_record(session) + # If there are no pending requests and no buffered updates, the queue is empty + return not has_pending_buffer_updates - # Commit updates to the metadata - if updated: - await session.commit() + # There are pending requests (may be inaccurate), ensure recalculated metadata + await self._add_buffer_record(session, update_modified_at=True, recalculate=True) - return empty + return False async def _get_state(self, session: AsyncSession) -> RequestQueueStateDb: """Get the current state of the request queue.""" @@ -628,6 +638,7 @@ async def _get_state(self, session: AsyncSession) -> RequestQueueStateDb: raise RuntimeError(f'Failed to create or retrieve state for queue {self._id}') return orm_state + @override def _specific_update_metadata( self, new_handled_request_count: int | None = None, @@ -635,6 +646,7 @@ def _specific_update_metadata( new_total_request_count: int | None = None, delta_handled_request_count: int | None = None, delta_pending_request_count: int | None = None, + delta_total_request_count: int | None = None, *, recalculate: bool = False, update_had_multiple_clients: bool = False, @@ -649,6 +661,7 @@ def _specific_update_metadata( new_total_request_count: If provided, update the total_request_count to this value. delta_handled_request_count: If provided, add this value to the handled_request_count. delta_pending_request_count: If provided, add this value to the pending_request_count. + delta_total_request_count: If provided, add this value to the total_request_count. recalculate: If True, recalculate the pending_request_count, and total_request_count on request table. update_had_multiple_clients: If True, set had_multiple_clients to True. 
""" @@ -657,23 +670,6 @@ def _specific_update_metadata( if update_had_multiple_clients: values_to_set['had_multiple_clients'] = True - if new_handled_request_count is not None: - values_to_set['handled_request_count'] = new_handled_request_count - elif delta_handled_request_count is not None: - values_to_set['handled_request_count'] = ( - self._METADATA_TABLE.handled_request_count + delta_handled_request_count - ) - - if new_pending_request_count is not None: - values_to_set['pending_request_count'] = new_pending_request_count - elif delta_pending_request_count is not None: - values_to_set['pending_request_count'] = ( - self._METADATA_TABLE.pending_request_count + delta_pending_request_count - ) - - if new_total_request_count is not None: - values_to_set['total_request_count'] = new_total_request_count - if recalculate: stmt = ( update(self._METADATA_TABLE) @@ -702,6 +698,28 @@ def _specific_update_metadata( values_to_set['custom_stmt'] = stmt + else: + if new_handled_request_count is not None: + values_to_set['handled_request_count'] = new_handled_request_count + elif delta_handled_request_count is not None: + values_to_set['handled_request_count'] = ( + self._METADATA_TABLE.handled_request_count + delta_handled_request_count + ) + + if new_pending_request_count is not None: + values_to_set['pending_request_count'] = new_pending_request_count + elif delta_pending_request_count is not None: + values_to_set['pending_request_count'] = ( + self._METADATA_TABLE.pending_request_count + delta_pending_request_count + ) + + if new_total_request_count is not None: + values_to_set['total_request_count'] = new_total_request_count + elif delta_total_request_count is not None: + values_to_set['total_request_count'] = ( + self._METADATA_TABLE.total_request_count + delta_total_request_count + ) + return values_to_set @staticmethod @@ -718,3 +736,82 @@ def _get_int_id_from_unique_key(unique_key: str) -> int: hashed_key = sha256(unique_key.encode('utf-8')).hexdigest() name_length = 15 return int(hashed_key[:name_length], 16) + + @override + def _prepare_buffer_data( + self, + delta_handled_request_count: int | None = None, + delta_pending_request_count: int | None = None, + delta_total_request_count: int | None = None, + *, + recalculate: bool = False, + **_kwargs: Any, + ) -> dict[str, Any]: + """Prepare request queue specific buffer data. + + Args: + delta_handled_request_count: If provided, add this value to the handled_request_count. + delta_pending_request_count: If provided, add this value to the pending_request_count. + delta_total_request_count: If provided, add this value to the total_request_count. + recalculate: If True, recalculate the pending_request_count, and total_request_count on request table. 
+ """ + buffer_data: dict[str, Any] = { + 'client_id': self.client_key, + } + + if delta_handled_request_count: + buffer_data['delta_handled_count'] = delta_handled_request_count + + if delta_pending_request_count: + buffer_data['delta_pending_count'] = delta_pending_request_count + + if delta_total_request_count: + buffer_data['delta_total_count'] = delta_total_request_count + + if recalculate: + buffer_data['need_recalc'] = True + + return buffer_data + + @override + async def _apply_buffer_updates(self, session: AsyncSession, max_buffer_id: int) -> None: + aggregations: list[ColumnElement[Any]] = [ + sql_func.max(self._BUFFER_TABLE.accessed_at).label('max_accessed_at'), + sql_func.max(self._BUFFER_TABLE.modified_at).label('max_modified_at'), + sql_func.sum(self._BUFFER_TABLE.delta_handled_count).label('delta_handled_count'), + sql_func.sum(self._BUFFER_TABLE.delta_pending_count).label('delta_pending_count'), + sql_func.sum(self._BUFFER_TABLE.delta_total_count).label('delta_total_count'), + ] + + if not self._had_multiple_clients: + aggregations.append( + sql_func.count(sql_func.distinct(self._BUFFER_TABLE.client_id)).label('unique_clients_count') + ) + + if self._storage_client.get_dialect_name() == 'postgresql': + aggregations.append(sql_func.bool_or(self._BUFFER_TABLE.need_recalc).label('need_recalc')) + else: + aggregations.append(sql_func.max(self._BUFFER_TABLE.need_recalc).label('need_recalc')) + + aggregation_stmt = select(*aggregations).where( + self._BUFFER_TABLE.storage_id == self._id, self._BUFFER_TABLE.id <= max_buffer_id + ) + + result = await session.execute(aggregation_stmt) + row = result.first() + + if not row: + return + + await self._update_metadata( + session, + **_QueueMetadataUpdateParams( + accessed_at=row.max_accessed_at, + modified_at=row.max_modified_at, + update_had_multiple_clients=not self._had_multiple_clients and row.unique_clients_count > 1, + delta_handled_request_count=row.delta_handled_count, + delta_pending_request_count=row.delta_pending_count, + delta_total_request_count=row.delta_total_count, + recalculate=row.need_recalc, + ), + ) diff --git a/src/crawlee/storage_clients/_sql/_storage_client.py b/src/crawlee/storage_clients/_sql/_storage_client.py index d324a17a86..54209d4302 100644 --- a/src/crawlee/storage_clients/_sql/_storage_client.py +++ b/src/crawlee/storage_clients/_sql/_storage_client.py @@ -2,7 +2,7 @@ import sys import warnings -from datetime import timedelta +from logging import getLogger from pathlib import Path from typing import TYPE_CHECKING @@ -26,6 +26,9 @@ from sqlalchemy.ext.asyncio import AsyncSession +logger = getLogger(__name__) + + @docs_group('Storage clients') class SqlStorageClient(StorageClient): """SQL implementation of the storage client. 
@@ -68,9 +71,6 @@ def __init__( self._initialized = False self.session_maker: None | async_sessionmaker[AsyncSession] = None - # Minimum interval to reduce database load from frequent concurrent metadata updates - self._accessed_modified_update_interval = timedelta(seconds=1) - # Flag needed to apply optimizations only for default database self._default_flag = self._engine is None and self._connection_string is None self._dialect_name: str | None = None @@ -106,10 +106,6 @@ def get_dialect_name(self) -> str | None: """Get the database dialect name.""" return self._dialect_name - def get_accessed_modified_update_interval(self) -> timedelta: - """Get the interval for accessed and modified updates.""" - return self._accessed_modified_update_interval - async def initialize(self, configuration: Configuration) -> None: """Initialize the database schema. @@ -140,9 +136,7 @@ async def initialize(self, configuration: Configuration) -> None: await conn.execute(text('PRAGMA mmap_size=268435456')) # 256MB memory mapping await conn.execute(text('PRAGMA foreign_keys=ON')) # Enforce constraints await conn.execute(text('PRAGMA busy_timeout=30000')) # 30s busy timeout - await conn.run_sync(Base.metadata.create_all, checkfirst=True) - from crawlee import __version__ # Noqa: PLC0415 db_version = (await conn.execute(select(VersionDb))).scalar_one_or_none() @@ -158,7 +152,6 @@ async def initialize(self, configuration: Configuration) -> None: ) elif not db_version: await conn.execute(insert(VersionDb).values(version=__version__)) - except (IntegrityError, OperationalError): await conn.rollback() @@ -280,10 +273,9 @@ def _get_or_create_engine(self, configuration: Configuration) -> AsyncEngine: self._engine = create_async_engine( connection_string, future=True, - pool_size=5, - max_overflow=10, - pool_timeout=30, - pool_recycle=600, + pool_size=10, + max_overflow=50, + pool_timeout=60, pool_pre_ping=True, echo=False, connect_args={'timeout': 30}, diff --git a/tests/unit/storage_clients/_sql/test_sql_dataset_client.py b/tests/unit/storage_clients/_sql/test_sql_dataset_client.py index 5ad4448d4c..f802b35c62 100644 --- a/tests/unit/storage_clients/_sql/test_sql_dataset_client.py +++ b/tests/unit/storage_clients/_sql/test_sql_dataset_client.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -from datetime import timedelta from typing import TYPE_CHECKING import pytest @@ -38,11 +37,9 @@ def get_tables(sync_conn: Connection) -> list[str]: @pytest.fixture async def dataset_client( configuration: Configuration, - monkeypatch: pytest.MonkeyPatch, ) -> AsyncGenerator[SqlDatasetClient, None]: """A fixture for a SQL dataset client.""" async with SqlStorageClient() as storage_client: - monkeypatch.setattr(storage_client, '_accessed_modified_update_interval', timedelta(seconds=0)) client = await storage_client.create_dataset_client( name='test-dataset', configuration=configuration, diff --git a/tests/unit/storage_clients/_sql/test_sql_kvs_client.py b/tests/unit/storage_clients/_sql/test_sql_kvs_client.py index 89ecc891c4..69a71b0aa8 100644 --- a/tests/unit/storage_clients/_sql/test_sql_kvs_client.py +++ b/tests/unit/storage_clients/_sql/test_sql_kvs_client.py @@ -2,7 +2,6 @@ import asyncio import json -from datetime import timedelta from typing import TYPE_CHECKING import pytest @@ -34,16 +33,13 @@ def configuration(tmp_path: Path) -> Configuration: @pytest.fixture async def kvs_client( configuration: Configuration, - monkeypatch: pytest.MonkeyPatch, ) -> AsyncGenerator[SqlKeyValueStoreClient, None]: """A 
fixture for a SQL key-value store client.""" async with SqlStorageClient() as storage_client: - monkeypatch.setattr(storage_client, '_accessed_modified_update_interval', timedelta(seconds=0)) client = await storage_client.create_kvs_client( name='test-kvs', configuration=configuration, ) - monkeypatch.setattr(client, '_accessed_modified_update_interval', timedelta(seconds=0)) yield client await client.drop() diff --git a/tests/unit/storage_clients/_sql/test_sql_rq_client.py b/tests/unit/storage_clients/_sql/test_sql_rq_client.py index c98b7a1fc0..1139d30bbf 100644 --- a/tests/unit/storage_clients/_sql/test_sql_rq_client.py +++ b/tests/unit/storage_clients/_sql/test_sql_rq_client.py @@ -2,7 +2,6 @@ import asyncio import json -from datetime import timedelta from typing import TYPE_CHECKING import pytest @@ -35,16 +34,13 @@ def configuration(tmp_path: Path) -> Configuration: @pytest.fixture async def rq_client( configuration: Configuration, - monkeypatch: pytest.MonkeyPatch, ) -> AsyncGenerator[SqlRequestQueueClient, None]: """A fixture for a SQL request queue client.""" async with SqlStorageClient() as storage_client: - monkeypatch.setattr(storage_client, '_accessed_modified_update_interval', timedelta(seconds=0)) client = await storage_client.create_rq_client( name='test-request-queue', configuration=configuration, ) - monkeypatch.setattr(client, '_accessed_modified_update_interval', timedelta(seconds=0)) yield client await client.drop()
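For reference, the aggregation performed by the `_apply_buffer_updates` implementations above can be exercised in isolation. The sketch below uses a simplified, hypothetical buffer table and a synchronous in-memory SQLite engine (the real clients use async sessions and the models from `_db_models.py`); it shows how a single query yields the newest timestamps, summed counter deltas, the distinct-client count, and the recalculation flag, with `max()` standing in for PostgreSQL's `bool_or()` on SQLite:

```python
from datetime import datetime, timedelta

from sqlalchemy import Boolean, Column, DateTime, Integer, MetaData, String, Table, create_engine, func, insert, select

metadata = MetaData()

# Hypothetical, simplified stand-in for the request_queue_metadata_buffer table.
buffer = Table(
    'rq_buffer',
    metadata,
    Column('id', Integer, primary_key=True, autoincrement=True),
    Column('storage_id', String(20), nullable=False),
    Column('client_id', String(32), nullable=False),
    Column('accessed_at', DateTime, nullable=False),
    Column('modified_at', DateTime, nullable=True),
    Column('delta_pending_count', Integer, nullable=True),
    Column('delta_handled_count', Integer, nullable=True),
    Column('need_recalc', Boolean, nullable=False, default=False),
)

engine = create_engine('sqlite:///:memory:')
metadata.create_all(engine)

now = datetime.now()  # naive timestamps keep the SQLite demo simple
with engine.begin() as conn:
    conn.execute(
        insert(buffer),
        [
            {'storage_id': 'rq1', 'client_id': 'a', 'accessed_at': now, 'modified_at': now,
             'delta_pending_count': 5, 'delta_handled_count': None, 'need_recalc': False},
            {'storage_id': 'rq1', 'client_id': 'b', 'accessed_at': now + timedelta(seconds=1),
             'modified_at': None, 'delta_pending_count': -1, 'delta_handled_count': 1, 'need_recalc': False},
        ],
    )

    # One aggregation query replaces many per-operation metadata UPDATEs:
    # newest timestamps win, counter deltas are summed, and the distinct
    # client count reveals multi-client access.
    row = conn.execute(
        select(
            func.max(buffer.c.accessed_at).label('accessed_at'),
            func.max(buffer.c.modified_at).label('modified_at'),
            func.sum(buffer.c.delta_pending_count).label('delta_pending'),
            func.sum(buffer.c.delta_handled_count).label('delta_handled'),
            func.count(func.distinct(buffer.c.client_id)).label('clients'),
            func.max(buffer.c.need_recalc).label('need_recalc'),  # bool_or() equivalent on SQLite
        ).where(buffer.c.storage_id == 'rq1')
    ).one()

    print(row.delta_pending, row.delta_handled, row.clients, bool(row.need_recalc))
```

After the aggregated values are applied to the metadata row, the processed buffer rows (those with `id <= max_buffer_id`) are deleted, mirroring `_process_buffers` above.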