diff --git a/docs/guides/storage_clients.mdx b/docs/guides/storage_clients.mdx
index 70c4964192..9e777782f4 100644
--- a/docs/guides/storage_clients.mdx
+++ b/docs/guides/storage_clients.mdx
@@ -208,6 +208,15 @@ class dataset_records {
+ data
}
+class dataset_metadata_buffer {
+    <<table>>
+ + id (PK)
+ + accessed_at
+ + modified_at
+ + dataset_id (FK)
+ + delta_item_count
+}
+
%% ========================
%% Key-Value Store Tables
%% ========================
@@ -231,15 +240,25 @@ class key_value_store_records {
+ size
}
+class key_value_store_metadata_buffer {
+    <<table>>
+ + id (PK)
+ + accessed_at
+ + modified_at
+ + key_value_store_id (FK)
+}
+
%% ========================
%% Client to Table arrows
%% ========================
SqlDatasetClient --> datasets
SqlDatasetClient --> dataset_records
+SqlDatasetClient --> dataset_metadata_buffer
SqlKeyValueStoreClient --> key_value_stores
SqlKeyValueStoreClient --> key_value_store_records
+SqlKeyValueStoreClient --> key_value_store_metadata_buffer
```
```mermaid
---
@@ -294,6 +313,19 @@ class request_queue_state {
+ forefront_sequence_counter
}
+class request_queue_metadata_buffer {
+    <<table>>
+ + id (PK)
+ + accessed_at
+ + modified_at
+ + request_queue_id (FK)
+ + client_id
+ + delta_handled_count
+ + delta_pending_count
+ + delta_total_count
+ + need_recalc
+}
+
%% ========================
%% Client to Table arrows
%% ========================
@@ -301,6 +333,7 @@ class request_queue_state {
SqlRequestQueueClient --> request_queues
SqlRequestQueueClient --> request_queue_records
SqlRequestQueueClient --> request_queue_state
+SqlRequestQueueClient --> request_queue_metadata_buffer
```
Configuration options for the `SqlStorageClient` can be set through environment variables or the `Configuration` class:
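For reference, a minimal usage sketch of pointing the `SqlStorageClient` at a custom database (not part of this diff; the `connection_string` keyword is assumed from the constructor attributes touched in this patch, and passing `storage_client` to `Dataset.open` follows the pattern used elsewhere in this guide):

```python
import asyncio

from crawlee.storage_clients import SqlStorageClient
from crawlee.storages import Dataset


async def main() -> None:
    # Assumed keyword: an explicit SQLAlchemy async connection string instead of
    # the default on-disk SQLite database.
    storage_client = SqlStorageClient(connection_string='sqlite+aiosqlite:///crawlee.db')

    # Open a dataset backed by the SQL storage client and push a single item.
    dataset = await Dataset.open(storage_client=storage_client)
    await dataset.push_data({'url': 'https://example.com'})


if __name__ == '__main__':
    asyncio.run(main())
```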
diff --git a/src/crawlee/storage_clients/_sql/_client_mixin.py b/src/crawlee/storage_clients/_sql/_client_mixin.py
index c681e3a220..234a0abdad 100644
--- a/src/crawlee/storage_clients/_sql/_client_mixin.py
+++ b/src/crawlee/storage_clients/_sql/_client_mixin.py
@@ -2,11 +2,12 @@
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
-from datetime import datetime, timezone
+from datetime import datetime, timedelta, timezone
from logging import getLogger
from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast, overload
-from sqlalchemy import delete, select, text, update
+from sqlalchemy import CursorResult, delete, select, text, update
+from sqlalchemy import func as sql_func
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.dialects.sqlite import insert as lite_insert
from sqlalchemy.exc import SQLAlchemyError
@@ -25,10 +26,13 @@
from ._db_models import (
DatasetItemDb,
+ DatasetMetadataBufferDb,
DatasetMetadataDb,
+ KeyValueStoreMetadataBufferDb,
KeyValueStoreMetadataDb,
KeyValueStoreRecordDb,
RequestDb,
+ RequestQueueMetadataBufferDb,
RequestQueueMetadataDb,
)
from ._storage_client import SqlStorageClient
@@ -40,9 +44,8 @@
class MetadataUpdateParams(TypedDict, total=False):
"""Parameters for updating metadata."""
- update_accessed_at: NotRequired[bool]
- update_modified_at: NotRequired[bool]
- force: NotRequired[bool]
+ accessed_at: NotRequired[datetime]
+ modified_at: NotRequired[datetime]
class SqlClientMixin(ABC):
@@ -57,21 +60,24 @@ class SqlClientMixin(ABC):
_METADATA_TABLE: ClassVar[type[DatasetMetadataDb | KeyValueStoreMetadataDb | RequestQueueMetadataDb]]
"""SQLAlchemy model for metadata."""
+ _BUFFER_TABLE: ClassVar[
+ type[KeyValueStoreMetadataBufferDb | DatasetMetadataBufferDb | RequestQueueMetadataBufferDb]
+ ]
+ """SQLAlchemy model for metadata buffer."""
+
_ITEM_TABLE: ClassVar[type[DatasetItemDb | KeyValueStoreRecordDb | RequestDb]]
"""SQLAlchemy model for items."""
_CLIENT_TYPE: ClassVar[str]
"""Human-readable client type for error messages."""
+ _BLOCK_BUFFER_TIME = timedelta(seconds=1)
+    """Duration for which the buffer processing lock is held while buffered metadata updates are applied."""
+
def __init__(self, *, id: str, storage_client: SqlStorageClient) -> None:
self._id = id
self._storage_client = storage_client
- # Time tracking to reduce database writes during frequent operation
- self._accessed_at_allow_update_after: datetime | None = None
- self._modified_at_allow_update_after: datetime | None = None
- self._accessed_modified_update_interval = storage_client.get_accessed_modified_update_interval()
-
@classmethod
async def _open(
cls,
@@ -109,7 +115,9 @@ async def _open(
if orm_metadata:
client = cls(id=orm_metadata.id, storage_client=storage_client)
- await client._update_metadata(session, update_accessed_at=True)
+ await client._add_buffer_record(session)
+ # Ensure any pending buffer updates are processed
+ await client._process_buffers()
else:
now = datetime.now(timezone.utc)
metadata = metadata_model(
@@ -121,8 +129,6 @@ async def _open(
**extra_metadata_fields,
)
client = cls(id=metadata.id, storage_client=storage_client)
- client._accessed_at_allow_update_after = now + client._accessed_modified_update_interval
- client._modified_at_allow_update_after = now + client._accessed_modified_update_interval
session.add(cls._METADATA_TABLE(**metadata.model_dump(), internal_name=internal_name))
return client
@@ -262,9 +268,12 @@ async def _purge(self, metadata_kwargs: MetadataUpdateParams) -> None:
Args:
metadata_kwargs: Arguments to pass to _update_metadata.
"""
- stmt = delete(self._ITEM_TABLE).where(self._ITEM_TABLE.storage_id == self._id)
+ # Process buffers to ensure metadata is up to date before purging
+ await self._process_buffers()
+
+ stmt_records = delete(self._ITEM_TABLE).where(self._ITEM_TABLE.storage_id == self._id)
async with self.get_session(with_simple_commit=True) as session:
- await session.execute(stmt)
+ await session.execute(stmt_records)
await self._update_metadata(session, **metadata_kwargs)
async def _drop(self) -> None:
@@ -290,6 +299,9 @@ async def _get_metadata(
self, metadata_model: type[DatasetMetadata | KeyValueStoreMetadata | RequestQueueMetadata]
) -> DatasetMetadata | KeyValueStoreMetadata | RequestQueueMetadata:
"""Retrieve client metadata."""
+ # Process any pending buffer updates first
+ await self._process_buffers()
+
async with self.get_session() as session:
orm_metadata = await session.get(self._METADATA_TABLE, self._id)
if not orm_metadata:
@@ -297,46 +309,6 @@ async def _get_metadata(
return metadata_model.model_validate(orm_metadata)
- def _default_update_metadata(
- self, *, update_accessed_at: bool = False, update_modified_at: bool = False, force: bool = False
- ) -> dict[str, Any]:
- """Prepare common metadata updates with rate limiting.
-
- Args:
- update_accessed_at: Whether to update accessed_at timestamp.
- update_modified_at: Whether to update modified_at timestamp.
- force: Whether to force the update regardless of rate limiting.
- """
- values_to_set: dict[str, Any] = {}
- now = datetime.now(timezone.utc)
-
- # If the record must be updated (for example, when updating counters), we update timestamps and shift the time.
- if force:
- if update_modified_at:
- values_to_set['modified_at'] = now
- self._modified_at_allow_update_after = now + self._accessed_modified_update_interval
- if update_accessed_at:
- values_to_set['accessed_at'] = now
- self._accessed_at_allow_update_after = now + self._accessed_modified_update_interval
-
- elif update_modified_at and (
- self._modified_at_allow_update_after is None or now >= self._modified_at_allow_update_after
- ):
- values_to_set['modified_at'] = now
- self._modified_at_allow_update_after = now + self._accessed_modified_update_interval
- # The record will be updated, we can update `accessed_at` and shift the time.
- if update_accessed_at:
- values_to_set['accessed_at'] = now
- self._accessed_at_allow_update_after = now + self._accessed_modified_update_interval
-
- elif update_accessed_at and (
- self._accessed_at_allow_update_after is None or now >= self._accessed_at_allow_update_after
- ):
- values_to_set['accessed_at'] = now
- self._accessed_at_allow_update_after = now + self._accessed_modified_update_interval
-
- return values_to_set
-
@abstractmethod
def _specific_update_metadata(self, **kwargs: Any) -> dict[str, Any]:
"""Prepare storage-specific metadata updates.
@@ -347,30 +319,42 @@ def _specific_update_metadata(self, **kwargs: Any) -> dict[str, Any]:
**kwargs: Storage-specific update parameters.
"""
+ @abstractmethod
+ def _prepare_buffer_data(self, **kwargs: Any) -> dict[str, Any]:
+ """Prepare storage-specific buffer data. Must be implemented by concrete classes."""
+
+ @abstractmethod
+ async def _apply_buffer_updates(self, session: AsyncSession, max_buffer_id: int) -> None:
+ """Apply aggregated buffer updates to metadata. Must be implemented by concrete classes.
+
+ Args:
+ session: Active database session.
+ max_buffer_id: Maximum buffer record ID to process.
+ """
+
async def _update_metadata(
self,
session: AsyncSession,
*,
- update_accessed_at: bool = False,
- update_modified_at: bool = False,
- force: bool = False,
+ accessed_at: datetime | None = None,
+ modified_at: datetime | None = None,
**kwargs: Any,
- ) -> bool:
- """Update storage metadata combining common and specific fields.
+ ) -> None:
+ """Directly update storage metadata combining common and specific fields.
Args:
session: Active database session.
- update_accessed_at: Whether to update accessed_at timestamp.
- update_modified_at: Whether to update modified_at timestamp.
- force: Whether to force the update timestamps regardless of rate limiting.
+ accessed_at: Datetime to set as accessed_at timestamp.
+ modified_at: Datetime to set as modified_at timestamp.
**kwargs: Additional arguments for _specific_update_metadata.
-
- Returns:
- True if any updates were made, False otherwise
"""
- values_to_set = self._default_update_metadata(
- update_accessed_at=update_accessed_at, update_modified_at=update_modified_at, force=force
- )
+ values_to_set: dict[str, Any] = {}
+
+ if accessed_at is not None:
+ values_to_set['accessed_at'] = accessed_at
+
+ if modified_at is not None:
+ values_to_set['modified_at'] = modified_at
values_to_set.update(self._specific_update_metadata(**kwargs))
@@ -380,6 +364,141 @@ async def _update_metadata(
stmt = stmt.values(**values_to_set)
await session.execute(stmt)
+
+ async def _add_buffer_record(
+ self,
+ session: AsyncSession,
+ *,
+ update_modified_at: bool = False,
+ **kwargs: Any,
+ ) -> None:
+        """Add a record to the buffer table for a deferred metadata update.
+
+ Args:
+ session: Active database session.
+ update_modified_at: Whether to update modified_at timestamp.
+ **kwargs: Additional arguments for _prepare_buffer_data.
+ """
+ now = datetime.now(timezone.utc)
+ values_to_set = {
+ 'storage_id': self._id,
+ 'accessed_at': now, # All entries in the buffer require updating `accessed_at`
+ 'modified_at': now if update_modified_at else None,
+ }
+ values_to_set.update(self._prepare_buffer_data(**kwargs))
+
+ session.add(self._BUFFER_TABLE(**values_to_set))
+
+ async def _try_acquire_buffer_lock(self, session: AsyncSession) -> bool:
+        """Try to acquire the buffer processing lock for the `_BLOCK_BUFFER_TIME` interval.
+
+ Args:
+ session: Active database session.
+
+ Returns:
+ True if lock was acquired, False if already locked by another process.
+ """
+ now = datetime.now(timezone.utc)
+ lock_until = now + self._BLOCK_BUFFER_TIME
+ dialect = self._storage_client.get_dialect_name()
+
+ if dialect == 'postgresql':
+ select_stmt = (
+ select(self._METADATA_TABLE)
+ .where(
+ self._METADATA_TABLE.id == self._id,
+ (self._METADATA_TABLE.buffer_locked_until.is_(None))
+ | (self._METADATA_TABLE.buffer_locked_until < now),
+ select(self._BUFFER_TABLE.id).where(self._BUFFER_TABLE.storage_id == self._id).exists(),
+ )
+ .with_for_update(skip_locked=True)
+ )
+ result = await session.execute(select_stmt)
+ metadata_row = result.scalar_one_or_none()
+
+ if metadata_row is None:
+ # Either conditions not met OR row is locked by another process
+ return False
+
+ # Acquire lock only if not currently locked or lock has expired
+ update_stmt = (
+ update(self._METADATA_TABLE)
+ .where(
+ self._METADATA_TABLE.id == self._id,
+ (self._METADATA_TABLE.buffer_locked_until.is_(None)) | (self._METADATA_TABLE.buffer_locked_until < now),
+ select(self._BUFFER_TABLE.id).where(self._BUFFER_TABLE.storage_id == self._id).exists(),
+ )
+ .values(buffer_locked_until=lock_until)
+ )
+
+ result = await session.execute(update_stmt)
+ result = cast('CursorResult', result) if not isinstance(result, CursorResult) else result
+
+ if result.rowcount > 0:
+ await session.flush()
return True
return False
+
+ async def _release_buffer_lock(self, session: AsyncSession) -> None:
+ """Release buffer processing lock by setting buffer_locked_until to NULL.
+
+ Args:
+ session: Active database session.
+ """
+ stmt = update(self._METADATA_TABLE).where(self._METADATA_TABLE.id == self._id).values(buffer_locked_until=None)
+
+ await session.execute(stmt)
+
+ async def _has_pending_buffer_updates(self, session: AsyncSession) -> bool:
+ """Check if there are pending buffer updates not yet applied to metadata.
+
+ Returns False only when buffer_locked_until is NULL (metadata is consistent).
+
+ Returns:
+ True if metadata might be inconsistent due to pending buffer updates.
+ """
+ result = await session.execute(
+ select(self._METADATA_TABLE.buffer_locked_until).where(self._METADATA_TABLE.id == self._id)
+ )
+
+ locked_until = result.scalar()
+
+ # Any non-NULL value means there are pending updates
+ return locked_until is not None
+
+ async def _process_buffers(self) -> None:
+ """Process pending buffer updates and apply them to metadata."""
+ async with self.get_session(with_simple_commit=True) as session:
+ # Try to acquire buffer processing lock
+ if not await self._try_acquire_buffer_lock(session):
+ # Another process is currently processing buffers or lock acquisition failed
+ return
+
+ # Get the maximum buffer ID at this moment
+ # This creates a consistent snapshot - records added during processing won't be included
+ max_buffer_id_stmt = select(sql_func.max(self._BUFFER_TABLE.id)).where(
+ self._BUFFER_TABLE.storage_id == self._id
+ )
+
+ result = await session.execute(max_buffer_id_stmt)
+ max_buffer_id = result.scalar()
+
+ if max_buffer_id is None:
+ # No buffer records to process. Release the lock and exit.
+ await self._release_buffer_lock(session)
+ return
+
+ # Apply aggregated buffer updates to metadata using only records <= max_buffer_id
+ # This method is implemented by concrete storage classes
+ await self._apply_buffer_updates(session, max_buffer_id=max_buffer_id)
+
+ # Clean up only the processed buffer records (those <= max_buffer_id)
+ delete_stmt = delete(self._BUFFER_TABLE).where(
+ self._BUFFER_TABLE.storage_id == self._id, self._BUFFER_TABLE.id <= max_buffer_id
+ )
+
+ await session.execute(delete_stmt)
+
+ # Release the lock after successful processing
+ await self._release_buffer_lock(session)
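As a side note outside the diff, the aggregate-then-delete step that `_process_buffers` delegates to `_apply_buffer_updates` can be illustrated with a small self-contained sketch (plain synchronous SQLAlchemy on in-memory SQLite, toy table names): timestamps fold with `MAX`, counter deltas with `SUM`, and only rows up to the snapshotted id are consumed, so rows appended concurrently survive for the next pass.

```python
from datetime import datetime, timezone

from sqlalchemy import DateTime, Integer, String, create_engine, delete, func, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Buffer(Base):
    """Toy stand-in for the *_metadata_buffer tables introduced in this patch."""

    __tablename__ = 'metadata_buffer'

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    storage_id: Mapped[str] = mapped_column(String(20), index=True)
    accessed_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
    delta_item_count: Mapped[int | None] = mapped_column(Integer, nullable=True)


engine = create_engine('sqlite:///:memory:')
Base.metadata.create_all(engine)

with Session(engine) as session:
    now = datetime.now(timezone.utc)
    session.add_all([Buffer(storage_id='ds1', accessed_at=now, delta_item_count=n) for n in (3, 5, None)])
    session.commit()

    # Snapshot the highest buffered id first; everything up to it is aggregated and deleted,
    # while rows inserted after the snapshot are left for the next processing pass.
    max_id = session.execute(select(func.max(Buffer.id)).where(Buffer.storage_id == 'ds1')).scalar()
    totals = session.execute(
        select(
            func.max(Buffer.accessed_at),  # newest access wins
            func.sum(Buffer.delta_item_count),  # NULL deltas are ignored by SUM
        ).where(Buffer.storage_id == 'ds1', Buffer.id <= max_id)
    ).one()
    print(totals)  # (<latest accessed_at>, 8)

    session.execute(delete(Buffer).where(Buffer.storage_id == 'ds1', Buffer.id <= max_id))
    session.commit()
```

Deleting only the snapshotted range keeps the lock window short: writers never wait on the aggregation, they just append new buffer rows.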
diff --git a/src/crawlee/storage_clients/_sql/_dataset_client.py b/src/crawlee/storage_clients/_sql/_dataset_client.py
index 61873975c8..1ea9d6b7cb 100644
--- a/src/crawlee/storage_clients/_sql/_dataset_client.py
+++ b/src/crawlee/storage_clients/_sql/_dataset_client.py
@@ -1,21 +1,24 @@
from __future__ import annotations
+from datetime import datetime, timezone
from logging import getLogger
from typing import TYPE_CHECKING, Any
from sqlalchemy import Select, insert, select
+from sqlalchemy import func as sql_func
from typing_extensions import Self, override
from crawlee.storage_clients._base import DatasetClient
from crawlee.storage_clients.models import DatasetItemsListPage, DatasetMetadata
from ._client_mixin import MetadataUpdateParams, SqlClientMixin
-from ._db_models import DatasetItemDb, DatasetMetadataDb
+from ._db_models import DatasetItemDb, DatasetMetadataBufferDb, DatasetMetadataDb
if TYPE_CHECKING:
from collections.abc import AsyncIterator
from sqlalchemy import Select
+ from sqlalchemy.ext.asyncio import AsyncSession
from typing_extensions import NotRequired
from ._storage_client import SqlStorageClient
@@ -40,6 +43,7 @@ class SqlDatasetClient(DatasetClient, SqlClientMixin):
The dataset data is stored in SQL database tables following the pattern:
- `datasets` table: Contains dataset metadata (id, name, timestamps, item_count)
- `dataset_records` table: Contains individual items with JSON data and auto-increment ordering
+    - `dataset_metadata_buffer` table: Buffers deferred metadata updates to reduce write contention
Items are stored as a JSON object in SQLite and as JSONB in PostgreSQL. These objects must be JSON-serializable.
The `item_id` auto-increment primary key ensures insertion order is preserved.
@@ -58,6 +62,9 @@ class SqlDatasetClient(DatasetClient, SqlClientMixin):
_CLIENT_TYPE = 'Dataset'
"""Human-readable client type for error messages."""
+ _BUFFER_TABLE = DatasetMetadataBufferDb
+ """SQLAlchemy model for metadata buffer."""
+
def __init__(
self,
*,
@@ -121,12 +128,12 @@ async def purge(self) -> None:
Resets item_count to 0 and deletes all records from dataset_records table.
"""
+ now = datetime.now(timezone.utc)
await self._purge(
metadata_kwargs=_DatasetMetadataUpdateParams(
new_item_count=0,
- update_accessed_at=True,
- update_modified_at=True,
- force=True,
+ accessed_at=now,
+ modified_at=now,
)
)
@@ -135,23 +142,13 @@ async def push_data(self, data: list[dict[str, Any]] | dict[str, Any]) -> None:
if not isinstance(data, list):
data = [data]
- db_items: list[dict[str, Any]] = []
db_items = [{'dataset_id': self._id, 'data': item} for item in data]
stmt = insert(self._ITEM_TABLE).values(db_items)
async with self.get_session(with_simple_commit=True) as session:
await session.execute(stmt)
- await self._update_metadata(
- session,
- **_DatasetMetadataUpdateParams(
- update_accessed_at=True,
- update_modified_at=True,
- delta_item_count=len(data),
- new_item_count=len(data),
- force=True,
- ),
- )
+ await self._add_buffer_record(session, update_modified_at=True, delta_item_count=len(data))
@override
async def get_data(
@@ -183,15 +180,11 @@ async def get_data(
view=view,
)
- async with self.get_session() as session:
+ async with self.get_session(with_simple_commit=True) as session:
result = await session.execute(stmt)
db_items = result.scalars().all()
- updated = await self._update_metadata(session, **_DatasetMetadataUpdateParams(update_accessed_at=True))
-
- # Commit updates to the metadata
- if updated:
- await session.commit()
+ await self._add_buffer_record(session)
items = [db_item.data for db_item in db_items]
metadata = await self.get_metadata()
@@ -230,17 +223,13 @@ async def iterate_items(
skip_hidden=skip_hidden,
)
- async with self.get_session() as session:
+ async with self.get_session(with_simple_commit=True) as session:
db_items = await session.stream_scalars(stmt)
async for db_item in db_items:
yield db_item.data
- updated = await self._update_metadata(session, **_DatasetMetadataUpdateParams(update_accessed_at=True))
-
- # Commit updates to the metadata
- if updated:
- await session.commit()
+ await self._add_buffer_record(session)
def _prepare_get_stmt(
self,
@@ -286,13 +275,14 @@ def _prepare_get_stmt(
return stmt.offset(offset).limit(limit)
+ @override
def _specific_update_metadata(
self,
new_item_count: int | None = None,
delta_item_count: int | None = None,
**_kwargs: dict[str, Any],
) -> dict[str, Any]:
- """Update the dataset metadata in the database.
+ """Directly update the dataset metadata in the database.
Args:
session: The SQLAlchemy AsyncSession to use for the update.
@@ -308,3 +298,39 @@ def _specific_update_metadata(
values_to_set['item_count'] = self._METADATA_TABLE.item_count + delta_item_count
return values_to_set
+
+ @override
+ def _prepare_buffer_data(self, delta_item_count: int | None = None, **_kwargs: Any) -> dict[str, Any]:
+ """Prepare dataset specific buffer data.
+
+ Args:
+ delta_item_count: If provided, add this value to the current item count.
+ """
+        buffer_data: dict[str, Any] = {}
+ if delta_item_count is not None:
+ buffer_data['delta_item_count'] = delta_item_count
+
+ return buffer_data
+
+ @override
+ async def _apply_buffer_updates(self, session: AsyncSession, max_buffer_id: int) -> None:
+ aggregation_stmt = select(
+ sql_func.max(self._BUFFER_TABLE.accessed_at).label('max_accessed_at'),
+ sql_func.max(self._BUFFER_TABLE.modified_at).label('max_modified_at'),
+ sql_func.sum(self._BUFFER_TABLE.delta_item_count).label('delta_item_count'),
+ ).where(self._BUFFER_TABLE.storage_id == self._id, self._BUFFER_TABLE.id <= max_buffer_id)
+
+ result = await session.execute(aggregation_stmt)
+ row = result.first()
+
+ if not row:
+ return
+
+ await self._update_metadata(
+ session,
+ **_DatasetMetadataUpdateParams(
+ accessed_at=row.max_accessed_at,
+ modified_at=row.max_modified_at,
+ delta_item_count=row.delta_item_count,
+ ),
+ )
diff --git a/src/crawlee/storage_clients/_sql/_db_models.py b/src/crawlee/storage_clients/_sql/_db_models.py
index 2a8f8b565b..00f38f200b 100644
--- a/src/crawlee/storage_clients/_sql/_db_models.py
+++ b/src/crawlee/storage_clients/_sql/_db_models.py
@@ -52,10 +52,10 @@ class Base(DeclarativeBase):
class StorageMetadataDb:
"""Base database model for storage metadata."""
- internal_name: Mapped[str] = mapped_column(String, nullable=False, index=True, unique=True)
+ internal_name: Mapped[str] = mapped_column(String(255), nullable=False, index=True, unique=True)
"""Internal unique name for a storage instance based on a name or alias."""
- name: Mapped[str | None] = mapped_column(String, nullable=True, unique=True)
+ name: Mapped[str | None] = mapped_column(String(255), nullable=True, unique=True)
"""Human-readable name. None becomes 'default' in database to enforce uniqueness."""
accessed_at: Mapped[datetime] = mapped_column(AwareDateTime, nullable=False)
@@ -67,6 +67,9 @@ class StorageMetadataDb:
modified_at: Mapped[datetime] = mapped_column(AwareDateTime, nullable=False)
"""Last modification datetime."""
+ buffer_locked_until: Mapped[datetime | None] = mapped_column(AwareDateTime, nullable=True)
+ """Timestamp until which buffer processing is locked for this storage. NULL = unlocked."""
+
class DatasetMetadataDb(StorageMetadataDb, Base):
"""Metadata table for datasets."""
@@ -84,6 +87,11 @@ class DatasetMetadataDb(StorageMetadataDb, Base):
back_populates='dataset', cascade='all, delete-orphan', lazy='noload'
)
+ # Relationship to metadata buffer updates
+ buffer: Mapped[list[DatasetMetadataBufferDb]] = relationship(
+ back_populates='dataset', cascade='all, delete-orphan', lazy='noload'
+ )
+
id = synonym('dataset_id')
"""Alias for dataset_id to match Pydantic expectations."""
@@ -92,6 +100,13 @@ class RequestQueueMetadataDb(StorageMetadataDb, Base):
"""Metadata table for request queues."""
__tablename__ = 'request_queues'
+ __table_args__ = (
+ Index(
+ 'idx_buffer_lock',
+ 'request_queue_id',
+ 'buffer_locked_until',
+ ),
+ )
request_queue_id: Mapped[str] = mapped_column(String(20), nullable=False, primary_key=True)
"""Unique identifier for the request queue."""
@@ -117,6 +132,11 @@ class RequestQueueMetadataDb(StorageMetadataDb, Base):
back_populates='queue', cascade='all, delete-orphan', lazy='noload'
)
+ # Relationship to metadata buffer updates
+ buffer: Mapped[list[RequestQueueMetadataBufferDb]] = relationship(
+ back_populates='queue', cascade='all, delete-orphan', lazy='noload'
+ )
+
id = synonym('request_queue_id')
"""Alias for request_queue_id to match Pydantic expectations."""
@@ -134,6 +154,11 @@ class KeyValueStoreMetadataDb(StorageMetadataDb, Base):
back_populates='kvs', cascade='all, delete-orphan', lazy='noload'
)
+ # Relationship to metadata buffer updates
+ buffer: Mapped[list[KeyValueStoreMetadataBufferDb]] = relationship(
+ back_populates='kvs', cascade='all, delete-orphan', lazy='noload'
+ )
+
id = synonym('key_value_store_id')
"""Alias for key_value_store_id to match Pydantic expectations."""
@@ -206,7 +231,12 @@ class RequestDb(Base):
'request_queue_id',
'is_handled',
'sequence_number',
- postgresql_where=text('is_handled is false'),
+ postgresql_where=text('is_handled = false'),
+ ),
+ Index(
+ 'idx_count_aggregate',
+ 'request_queue_id',
+ 'is_handled',
),
)
@@ -218,7 +248,7 @@ class RequestDb(Base):
)
"""Foreign key to metadata request queue record."""
- data: Mapped[str] = mapped_column(String, nullable=False)
+ data: Mapped[str] = mapped_column(String(5000), nullable=False)
"""JSON-serialized Request object."""
sequence_number: Mapped[int] = mapped_column(Integer, nullable=False)
@@ -266,3 +296,90 @@ class VersionDb(Base):
__tablename__ = 'version'
version: Mapped[str] = mapped_column(String(10), nullable=False, primary_key=True)
+
+
+class MetadataBufferDb:
+ """Base model for metadata update buffer tables."""
+
+ id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
+ """Auto-increment primary key for ordering."""
+
+ # Timestamp fields - use max value when aggregating
+    accessed_at: Mapped[datetime] = mapped_column(AwareDateTime, nullable=False)
+    """Accessed_at timestamp recorded for this buffered update (always set)."""
+
+ modified_at: Mapped[datetime | None] = mapped_column(AwareDateTime, nullable=True)
+ """New modified_at timestamp, if being updated."""
+
+
+class KeyValueStoreMetadataBufferDb(MetadataBufferDb, Base):
+ """Buffer table for deferred key-value store metadata updates to reduce concurrent access issues."""
+
+ __tablename__ = 'key_value_store_metadata_buffer'
+
+ key_value_store_id: Mapped[str] = mapped_column(
+ String(20), ForeignKey('key_value_stores.key_value_store_id', ondelete='CASCADE'), nullable=False, index=True
+ )
+ """ID of the key-value store being updated."""
+
+ # Relationship back to key-value store metadata
+ kvs: Mapped[KeyValueStoreMetadataDb] = relationship(back_populates='buffer')
+
+ storage_id = synonym('key_value_store_id')
+ """Alias for key_value_store_id to match SqlClientMixin expectations."""
+
+
+class DatasetMetadataBufferDb(MetadataBufferDb, Base):
+ """Buffer table for deferred dataset metadata updates to reduce concurrent access issues."""
+
+ __tablename__ = 'dataset_metadata_buffer'
+
+ dataset_id: Mapped[str] = mapped_column(
+ String(20), ForeignKey('datasets.dataset_id', ondelete='CASCADE'), nullable=False, index=True
+ )
+ """ID of the dataset being updated."""
+
+ # Counter deltas - use SUM when aggregating
+ delta_item_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
+ """Delta for dataset item_count."""
+
+ # Relationship back to dataset metadata
+ dataset: Mapped[DatasetMetadataDb] = relationship(back_populates='buffer')
+
+ storage_id = synonym('dataset_id')
+ """Alias for dataset_id to match SqlClientMixin expectations."""
+
+
+class RequestQueueMetadataBufferDb(MetadataBufferDb, Base):
+ """Buffer table for deferred request queue metadata updates to reduce concurrent access issues."""
+
+ __tablename__ = 'request_queue_metadata_buffer'
+
+ __table_args__ = (Index('idx_rq_client', 'request_queue_id', 'client_id'),)
+
+ request_queue_id: Mapped[str] = mapped_column(
+ String(20), ForeignKey('request_queues.request_queue_id', ondelete='CASCADE'), nullable=False, index=True
+ )
+ """ID of the request queue being updated."""
+
+ client_id: Mapped[str] = mapped_column(String(32), nullable=False)
+ """Identifier of the client making this update."""
+
+ # Counter deltas - use SUM when aggregating
+ delta_handled_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
+ """Delta for handled_request_count."""
+
+ delta_pending_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
+ """Delta for pending_request_count."""
+
+ delta_total_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
+ """Delta for total_request_count."""
+
+    need_recalc: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
+ """Flag indicating that counters need recalculation from actual data."""
+
+ # Relationship back to request queue metadata
+ queue: Mapped[RequestQueueMetadataDb] = relationship(back_populates='buffer')
+
+ storage_id = synonym('request_queue_id')
+ """Alias for request_queue_id to match SqlClientMixin expectations."""
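A brief note on the `storage_id = synonym(...)` lines above: the synonym gives every buffer table the same attribute name, so `SqlClientMixin` can filter, aggregate, and delete buffer rows without knowing each table's concrete foreign-key column. A minimal illustration of the pattern with toy models (not crawlee's):

```python
from sqlalchemy import Integer, String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, synonym


class Base(DeclarativeBase):
    pass


class DatasetBuffer(Base):
    __tablename__ = 'dataset_buffer'

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    dataset_id: Mapped[str] = mapped_column(String(20))
    storage_id = synonym('dataset_id')  # shared name used by generic code


class QueueBuffer(Base):
    __tablename__ = 'queue_buffer'

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    request_queue_id: Mapped[str] = mapped_column(String(20))
    storage_id = synonym('request_queue_id')


def count_rows(session: Session, buffer_table: type[DatasetBuffer] | type[QueueBuffer], storage_id: str) -> int:
    # Generic code filters on `storage_id` regardless of the underlying column name.
    stmt = select(buffer_table).where(buffer_table.storage_id == storage_id)
    return len(session.execute(stmt).scalars().all())


engine = create_engine('sqlite:///:memory:')
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(DatasetBuffer(dataset_id='abc'))
    session.add(QueueBuffer(request_queue_id='xyz'))
    session.commit()
    print(count_rows(session, DatasetBuffer, 'abc'), count_rows(session, QueueBuffer, 'xyz'))  # 1 1
```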
diff --git a/src/crawlee/storage_clients/_sql/_key_value_store_client.py b/src/crawlee/storage_clients/_sql/_key_value_store_client.py
index dfa02d8014..3a06a98cc2 100644
--- a/src/crawlee/storage_clients/_sql/_key_value_store_client.py
+++ b/src/crawlee/storage_clients/_sql/_key_value_store_client.py
@@ -1,22 +1,30 @@
from __future__ import annotations
import json
+from datetime import datetime, timezone
from logging import getLogger
from typing import TYPE_CHECKING, Any, cast
from sqlalchemy import CursorResult, delete, select
+from sqlalchemy import func as sql_func
from typing_extensions import Self, override
from crawlee._utils.file import infer_mime_type
from crawlee.storage_clients._base import KeyValueStoreClient
-from crawlee.storage_clients.models import KeyValueStoreMetadata, KeyValueStoreRecord, KeyValueStoreRecordMetadata
+from crawlee.storage_clients.models import (
+ KeyValueStoreMetadata,
+ KeyValueStoreRecord,
+ KeyValueStoreRecordMetadata,
+)
from ._client_mixin import MetadataUpdateParams, SqlClientMixin
-from ._db_models import KeyValueStoreMetadataDb, KeyValueStoreRecordDb
+from ._db_models import KeyValueStoreMetadataBufferDb, KeyValueStoreMetadataDb, KeyValueStoreRecordDb
if TYPE_CHECKING:
from collections.abc import AsyncIterator
+ from sqlalchemy.ext.asyncio import AsyncSession
+
from ._storage_client import SqlStorageClient
@@ -34,6 +42,7 @@ class SqlKeyValueStoreClient(KeyValueStoreClient, SqlClientMixin):
- `key_value_stores` table: Contains store metadata (id, name, timestamps)
- `key_value_store_records` table: Contains individual key-value pairs with binary value storage, content type,
and size information
+    - `key_value_store_metadata_buffer` table: Buffers deferred metadata updates to reduce write contention
Values are serialized based on their type: JSON objects are stored as formatted JSON,
text values as UTF-8 encoded strings, and binary data as-is in the `LargeBinary` column.
@@ -57,6 +66,9 @@ class SqlKeyValueStoreClient(KeyValueStoreClient, SqlClientMixin):
_CLIENT_TYPE = 'Key-value store'
"""Human-readable client type for error messages."""
+ _BUFFER_TABLE = KeyValueStoreMetadataBufferDb
+ """SQLAlchemy model for metadata buffer."""
+
def __init__(
self,
*,
@@ -124,7 +136,8 @@ async def purge(self) -> None:
Remove all records from key_value_store_records table.
"""
- await self._purge(metadata_kwargs=MetadataUpdateParams(update_accessed_at=True, update_modified_at=True))
+ now = datetime.now(timezone.utc)
+ await self._purge(metadata_kwargs=MetadataUpdateParams(accessed_at=now, modified_at=now))
@override
async def set_value(self, *, key: str, value: Any, content_type: str | None = None) -> None:
@@ -165,9 +178,7 @@ async def set_value(self, *, key: str, value: Any, content_type: str | None = No
async with self.get_session(with_simple_commit=True) as session:
await session.execute(upsert_stmt)
- await self._update_metadata(
- session, **MetadataUpdateParams(update_accessed_at=True, update_modified_at=True)
- )
+ await self._add_buffer_record(session, update_modified_at=True)
@override
async def get_value(self, *, key: str) -> KeyValueStoreRecord | None:
@@ -175,15 +186,11 @@ async def get_value(self, *, key: str) -> KeyValueStoreRecord | None:
stmt = select(self._ITEM_TABLE).where(
self._ITEM_TABLE.key_value_store_id == self._id, self._ITEM_TABLE.key == key
)
- async with self.get_session() as session:
+ async with self.get_session(with_simple_commit=True) as session:
result = await session.execute(stmt)
record_db = result.scalar_one_or_none()
- updated = await self._update_metadata(session, **MetadataUpdateParams(update_accessed_at=True))
-
- # Commit updates to the metadata
- if updated:
- await session.commit()
+ await self._add_buffer_record(session)
if not record_db:
return None
@@ -231,11 +238,7 @@ async def delete_value(self, *, key: str) -> None:
# Update metadata if we actually deleted something
if result.rowcount > 0:
- await self._update_metadata(
- session, **MetadataUpdateParams(update_accessed_at=True, update_modified_at=True)
- )
-
- await session.commit()
+                await self._add_buffer_record(session, update_modified_at=True)
@override
async def iterate_keys(
@@ -259,7 +262,7 @@ async def iterate_keys(
if limit is not None:
stmt = stmt.limit(limit)
- async with self.get_session() as session:
+ async with self.get_session(with_simple_commit=True) as session:
result = await session.stream(stmt.execution_options(stream_results=True))
async for row in result:
@@ -269,26 +272,18 @@ async def iterate_keys(
size=row.size,
)
- updated = await self._update_metadata(session, **MetadataUpdateParams(update_accessed_at=True))
-
- # Commit updates to the metadata
- if updated:
- await session.commit()
+ await self._add_buffer_record(session)
@override
async def record_exists(self, *, key: str) -> bool:
stmt = select(self._ITEM_TABLE.key).where(
self._ITEM_TABLE.key_value_store_id == self._id, self._ITEM_TABLE.key == key
)
- async with self.get_session() as session:
+ async with self.get_session(with_simple_commit=True) as session:
# Check if record exists
result = await session.execute(stmt)
- updated = await self._update_metadata(session, **MetadataUpdateParams(update_accessed_at=True))
-
- # Commit updates to the metadata
- if updated:
- await session.commit()
+ await self._add_buffer_record(session)
return result.scalar_one_or_none() is not None
@@ -296,5 +291,36 @@ async def record_exists(self, *, key: str) -> bool:
async def get_public_url(self, *, key: str) -> str:
raise NotImplementedError('Public URLs are not supported for SQL key-value stores.')
+ @override
def _specific_update_metadata(self, **_kwargs: dict[str, Any]) -> dict[str, Any]:
return {}
+
+ @override
+ def _prepare_buffer_data(self, **_kwargs: Any) -> dict[str, Any]:
+ """Prepare key-value store specific buffer data.
+
+        The key-value store has no storage-specific counters to buffer, so an empty dict is returned;
+        the base buffer record already carries the accessed_at/modified_at timestamps.
+ """
+ return {}
+
+ @override
+ async def _apply_buffer_updates(self, session: AsyncSession, max_buffer_id: int) -> None:
+ aggregation_stmt = select(
+ sql_func.max(self._BUFFER_TABLE.accessed_at).label('max_accessed_at'),
+ sql_func.max(self._BUFFER_TABLE.modified_at).label('max_modified_at'),
+ ).where(self._BUFFER_TABLE.storage_id == self._id, self._BUFFER_TABLE.id <= max_buffer_id)
+
+ result = await session.execute(aggregation_stmt)
+ row = result.first()
+
+ if not row:
+ return
+
+ await self._update_metadata(
+ session,
+ **MetadataUpdateParams(
+ accessed_at=row.max_accessed_at,
+ modified_at=row.max_modified_at,
+ ),
+ )
diff --git a/src/crawlee/storage_clients/_sql/_request_queue_client.py b/src/crawlee/storage_clients/_sql/_request_queue_client.py
index f5a320bb21..cba4f18702 100644
--- a/src/crawlee/storage_clients/_sql/_request_queue_client.py
+++ b/src/crawlee/storage_clients/_sql/_request_queue_client.py
@@ -7,7 +7,8 @@
from logging import getLogger
from typing import TYPE_CHECKING, Any, cast
-from sqlalchemy import CursorResult, func, or_, select, update
+from sqlalchemy import CursorResult, exists, func, or_, select, update
+from sqlalchemy import func as sql_func
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import load_only
from typing_extensions import NotRequired, Self, override
@@ -23,12 +24,13 @@
)
from ._client_mixin import MetadataUpdateParams, SqlClientMixin
-from ._db_models import RequestDb, RequestQueueMetadataDb, RequestQueueStateDb
+from ._db_models import RequestDb, RequestQueueMetadataBufferDb, RequestQueueMetadataDb, RequestQueueStateDb
if TYPE_CHECKING:
from collections.abc import Sequence
from sqlalchemy.ext.asyncio import AsyncSession
+ from sqlalchemy.sql import ColumnElement
from ._storage_client import SqlStorageClient
@@ -44,6 +46,7 @@ class _QueueMetadataUpdateParams(MetadataUpdateParams):
new_total_request_count: NotRequired[int]
delta_handled_request_count: NotRequired[int]
delta_pending_request_count: NotRequired[int]
+ delta_total_request_count: NotRequired[int]
recalculate: NotRequired[bool]
update_had_multiple_clients: NotRequired[bool]
@@ -64,6 +67,7 @@ class SqlRequestQueueClient(RequestQueueClient, SqlClientMixin):
- `request_queue_records` table: Contains individual requests with JSON data, unique keys for deduplication,
sequence numbers for ordering, and processing status flags
- `request_queue_state` table: Maintains counters for sequence numbers to ensure proper ordering of requests.
+    - `request_queue_metadata_buffer` table: Buffers deferred metadata updates to reduce write contention
Requests are serialized to JSON for storage and maintain proper ordering through sequence
numbers. The implementation provides concurrent access safety through transaction
@@ -93,6 +97,9 @@ class SqlRequestQueueClient(RequestQueueClient, SqlClientMixin):
"""Number of seconds for which a request is considered blocked in the database after being fetched for processing.
"""
+ _BUFFER_TABLE = RequestQueueMetadataBufferDb
+ """SQLAlchemy model for metadata buffer."""
+
def __init__(
self,
*,
@@ -111,6 +118,9 @@ def __init__(
self.client_key = crypto_random_object_id(length=32)[:32]
"""Unique identifier for this client instance."""
+ self._had_multiple_clients = False
+ """Indicates whether the queue has been accessed by multiple clients."""
+
@classmethod
async def open(
cls,
@@ -155,7 +165,9 @@ async def open(
@override
async def get_metadata(self) -> RequestQueueMetadata:
# The database is a single place of truth
- return await self._get_metadata(RequestQueueMetadata)
+ metadata = await self._get_metadata(RequestQueueMetadata)
+ self._had_multiple_clients = metadata.had_multiple_clients
+ return metadata
@override
async def drop(self) -> None:
@@ -174,12 +186,12 @@ async def purge(self) -> None:
Resets pending_request_count and handled_request_count to 0 and deletes all records from request_queue_records
table.
"""
+ now = datetime.now(timezone.utc)
await self._purge(
metadata_kwargs=_QueueMetadataUpdateParams(
- update_accessed_at=True,
- update_modified_at=True,
+ accessed_at=now,
+ modified_at=now,
new_pending_request_count=0,
- force=True,
)
)
@@ -202,7 +214,7 @@ async def add_batch_of_requests(
transaction_processed_requests = []
transaction_processed_requests_unique_keys = set()
- metadata_recalculate = False
+ approximate_new_request = 0
# Deduplicate requests by unique_key upfront
unique_requests = {}
@@ -254,7 +266,6 @@ async def add_batch_of_requests(
state.sequence_counter += 1
insert_values.append(value)
- metadata_recalculate = True
transaction_processed_requests.append(
ProcessedRequest(
unique_key=request.unique_key,
@@ -332,20 +343,20 @@ async def add_batch_of_requests(
update_columns=['sequence_number'],
conflict_cols=['request_id', 'request_queue_id'],
)
- await session.execute(upsert_stmt)
+ result = await session.execute(upsert_stmt)
else:
# If the request already exists in the database, we ignore this request when inserting.
insert_stmt_with_ignore = self._build_insert_stmt_with_ignore(self._ITEM_TABLE, insert_values)
- await session.execute(insert_stmt_with_ignore)
+ result = await session.execute(insert_stmt_with_ignore)
+
+ result = cast('CursorResult', result) if not isinstance(result, CursorResult) else result
+ approximate_new_request += result.rowcount
- await self._update_metadata(
+ await self._add_buffer_record(
session,
- **_QueueMetadataUpdateParams(
- recalculate=metadata_recalculate,
- update_modified_at=True,
- update_accessed_at=True,
- force=metadata_recalculate,
- ),
+ update_modified_at=True,
+ delta_pending_request_count=approximate_new_request,
+ delta_total_request_count=approximate_new_request,
)
try:
@@ -354,8 +365,10 @@ async def add_batch_of_requests(
except SQLAlchemyError as e:
await session.rollback()
logger.warning(f'Failed to commit session: {e}')
- await self._update_metadata(
- session, recalculate=True, update_modified_at=True, update_accessed_at=True, force=True
+ await self._add_buffer_record(
+ session,
+ update_modified_at=True,
+                    recalculate=True,
)
await session.commit()
transaction_processed_requests.clear()
@@ -383,7 +396,7 @@ async def get_request(self, unique_key: str) -> Request | None:
stmt = select(self._ITEM_TABLE).where(
self._ITEM_TABLE.request_queue_id == self._id, self._ITEM_TABLE.request_id == request_id
)
- async with self.get_session() as session:
+ async with self.get_session(with_simple_commit=True) as session:
result = await session.execute(stmt)
request_db = result.scalar_one_or_none()
@@ -391,11 +404,7 @@ async def get_request(self, unique_key: str) -> Request | None:
logger.warning(f'Request with ID "{unique_key}" not found in the queue.')
return None
- updated = await self._update_metadata(session, update_accessed_at=True)
-
- # Commit updates to the metadata
- if updated:
- await session.commit()
+ await self._add_buffer_record(session)
return Request.model_validate_json(request_db.data)
@@ -413,14 +422,14 @@ async def fetch_next_request(self) -> Request | None:
select(self._ITEM_TABLE)
.where(
self._ITEM_TABLE.request_queue_id == self._id,
- self._ITEM_TABLE.is_handled.is_(False),
+ self._ITEM_TABLE.is_handled == False, # noqa: E712
or_(self._ITEM_TABLE.time_blocked_until.is_(None), self._ITEM_TABLE.time_blocked_until < now),
)
.order_by(self._ITEM_TABLE.sequence_number.asc())
.limit(self._MAX_BATCH_FETCH_SIZE)
)
- async with self.get_session() as session:
+ async with self.get_session(with_simple_commit=True) as session:
# We use the `skip_locked` database mechanism to prevent the 'interception' of requests by another client
if dialect == 'postgresql':
stmt = stmt.with_for_update(skip_locked=True)
@@ -456,7 +465,7 @@ async def fetch_next_request(self) -> Request | None:
.where(
self._ITEM_TABLE.request_queue_id == self._id,
self._ITEM_TABLE.request_id.in_(request_ids),
- self._ITEM_TABLE.is_handled.is_(False),
+ self._ITEM_TABLE.is_handled == False, # noqa: E712
or_(self._ITEM_TABLE.time_blocked_until.is_(None), self._ITEM_TABLE.time_blocked_until < now),
)
.values(time_blocked_until=block_until, client_key=self.client_key)
@@ -470,9 +479,7 @@ async def fetch_next_request(self) -> Request | None:
await session.rollback()
return None
- await self._update_metadata(session, **_QueueMetadataUpdateParams(update_accessed_at=True))
-
- await session.commit()
+ await self._add_buffer_record(session)
requests = [Request.model_validate_json(r.data) for r in requests_db if r.request_id in blocked_ids]
@@ -497,7 +504,7 @@ async def mark_request_as_handled(self, request: Request) -> ProcessedRequest |
.where(self._ITEM_TABLE.request_queue_id == self._id, self._ITEM_TABLE.request_id == request_id)
.values(is_handled=True, time_blocked_until=None, client_key=None, data=request.model_dump_json())
)
- async with self.get_session() as session:
+ async with self.get_session(with_simple_commit=True) as session:
result = await session.execute(stmt)
result = cast('CursorResult', result) if not isinstance(result, CursorResult) else result
@@ -505,17 +512,9 @@ async def mark_request_as_handled(self, request: Request) -> ProcessedRequest |
logger.warning(f'Request {request.unique_key} not found in database.')
return None
- await self._update_metadata(
- session,
- **_QueueMetadataUpdateParams(
- delta_handled_request_count=1,
- delta_pending_request_count=-1,
- update_modified_at=True,
- update_accessed_at=True,
- force=True,
- ),
+ await self._add_buffer_record(
+ session, update_modified_at=True, delta_pending_request_count=-1, delta_handled_request_count=1
)
- await session.commit()
return ProcessedRequest(
unique_key=request.unique_key,
was_already_present=True,
@@ -567,9 +566,7 @@ async def reclaim_request(
if result.rowcount == 0:
logger.warning(f'Request {request.unique_key} not found in database.')
return None
- await self._update_metadata(
- session, **_QueueMetadataUpdateParams(update_modified_at=True, update_accessed_at=True)
- )
+ await self._add_buffer_record(session, update_modified_at=True)
# put the forefront request at the beginning of the cache
if forefront:
@@ -587,30 +584,43 @@ async def is_empty(self) -> bool:
if self._pending_fetch_cache:
return False
- # Check database for unhandled requests
- async with self.get_session() as session:
- metadata_orm = await session.get(self._METADATA_TABLE, self._id)
- if not metadata_orm:
- raise ValueError(f'Request queue with ID "{self._id}" not found.')
+ metadata = await self.get_metadata()
- empty = metadata_orm.pending_request_count == 0
+ async with self.get_session(with_simple_commit=True) as session:
+ # If there are no pending requests, check if there are any buffered updates
+ if metadata.pending_request_count == 0:
+ # Check for active buffer lock (indicates pending buffer processing)
+ buffer_lock_stmt = select(self._METADATA_TABLE.buffer_locked_until).where(
+ self._METADATA_TABLE.id == self._id
+ )
+ buffer_lock_result = await session.execute(buffer_lock_stmt)
+ buffer_locked_until = buffer_lock_result.scalar()
+
+ # If buffer is locked, there are pending updates being processed
+ if buffer_locked_until is not None:
+ await self._add_buffer_record(session)
+ return False
+
+ # Check if there are any buffered updates that might change the pending count
+ buffer_check_stmt = select(
+ exists().where(
+ (self._BUFFER_TABLE.storage_id == self._id)
+ & (
+ (self._BUFFER_TABLE.delta_pending_count != 0) | (self._BUFFER_TABLE.need_recalc == True) # noqa: E712
+ )
+ )
+ )
+ buffer_result = await session.execute(buffer_check_stmt)
+ has_pending_buffer_updates = buffer_result.scalar()
- updated = await self._update_metadata(
- session,
- **_QueueMetadataUpdateParams(
- update_accessed_at=True,
- # With multi-client access, counters may become out of sync.
- # If the queue is not empty, we perform a recalculation to synchronize the counters in the metadata.
- recalculate=not empty,
- update_modified_at=not empty,
- ),
- )
+ await self._add_buffer_record(session)
+ # If there are no pending requests and no buffered updates, the queue is empty
+ return not has_pending_buffer_updates
- # Commit updates to the metadata
- if updated:
- await session.commit()
+            # The pending count is non-zero (it may be stale), so request a recalculation to resynchronize metadata
+ await self._add_buffer_record(session, update_modified_at=True, recalculate=True)
- return empty
+ return False
async def _get_state(self, session: AsyncSession) -> RequestQueueStateDb:
"""Get the current state of the request queue."""
@@ -628,6 +638,7 @@ async def _get_state(self, session: AsyncSession) -> RequestQueueStateDb:
raise RuntimeError(f'Failed to create or retrieve state for queue {self._id}')
return orm_state
+ @override
def _specific_update_metadata(
self,
new_handled_request_count: int | None = None,
@@ -635,6 +646,7 @@ def _specific_update_metadata(
new_total_request_count: int | None = None,
delta_handled_request_count: int | None = None,
delta_pending_request_count: int | None = None,
+ delta_total_request_count: int | None = None,
*,
recalculate: bool = False,
update_had_multiple_clients: bool = False,
@@ -649,6 +661,7 @@ def _specific_update_metadata(
new_total_request_count: If provided, update the total_request_count to this value.
delta_handled_request_count: If provided, add this value to the handled_request_count.
delta_pending_request_count: If provided, add this value to the pending_request_count.
+ delta_total_request_count: If provided, add this value to the total_request_count.
recalculate: If True, recalculate the pending_request_count, and total_request_count on request table.
update_had_multiple_clients: If True, set had_multiple_clients to True.
"""
@@ -657,23 +670,6 @@ def _specific_update_metadata(
if update_had_multiple_clients:
values_to_set['had_multiple_clients'] = True
- if new_handled_request_count is not None:
- values_to_set['handled_request_count'] = new_handled_request_count
- elif delta_handled_request_count is not None:
- values_to_set['handled_request_count'] = (
- self._METADATA_TABLE.handled_request_count + delta_handled_request_count
- )
-
- if new_pending_request_count is not None:
- values_to_set['pending_request_count'] = new_pending_request_count
- elif delta_pending_request_count is not None:
- values_to_set['pending_request_count'] = (
- self._METADATA_TABLE.pending_request_count + delta_pending_request_count
- )
-
- if new_total_request_count is not None:
- values_to_set['total_request_count'] = new_total_request_count
-
if recalculate:
stmt = (
update(self._METADATA_TABLE)
@@ -702,6 +698,28 @@ def _specific_update_metadata(
values_to_set['custom_stmt'] = stmt
+ else:
+ if new_handled_request_count is not None:
+ values_to_set['handled_request_count'] = new_handled_request_count
+ elif delta_handled_request_count is not None:
+ values_to_set['handled_request_count'] = (
+ self._METADATA_TABLE.handled_request_count + delta_handled_request_count
+ )
+
+ if new_pending_request_count is not None:
+ values_to_set['pending_request_count'] = new_pending_request_count
+ elif delta_pending_request_count is not None:
+ values_to_set['pending_request_count'] = (
+ self._METADATA_TABLE.pending_request_count + delta_pending_request_count
+ )
+
+ if new_total_request_count is not None:
+ values_to_set['total_request_count'] = new_total_request_count
+ elif delta_total_request_count is not None:
+ values_to_set['total_request_count'] = (
+ self._METADATA_TABLE.total_request_count + delta_total_request_count
+ )
+
return values_to_set
@staticmethod
@@ -718,3 +736,82 @@ def _get_int_id_from_unique_key(unique_key: str) -> int:
hashed_key = sha256(unique_key.encode('utf-8')).hexdigest()
name_length = 15
return int(hashed_key[:name_length], 16)
+
+ @override
+ def _prepare_buffer_data(
+ self,
+ delta_handled_request_count: int | None = None,
+ delta_pending_request_count: int | None = None,
+ delta_total_request_count: int | None = None,
+ *,
+ recalculate: bool = False,
+ **_kwargs: Any,
+ ) -> dict[str, Any]:
+ """Prepare request queue specific buffer data.
+
+ Args:
+ delta_handled_request_count: If provided, add this value to the handled_request_count.
+ delta_pending_request_count: If provided, add this value to the pending_request_count.
+ delta_total_request_count: If provided, add this value to the total_request_count.
+            recalculate: If True, recalculate pending_request_count and total_request_count from the request table.
+ """
+ buffer_data: dict[str, Any] = {
+ 'client_id': self.client_key,
+ }
+
+ if delta_handled_request_count:
+ buffer_data['delta_handled_count'] = delta_handled_request_count
+
+ if delta_pending_request_count:
+ buffer_data['delta_pending_count'] = delta_pending_request_count
+
+ if delta_total_request_count:
+ buffer_data['delta_total_count'] = delta_total_request_count
+
+ if recalculate:
+ buffer_data['need_recalc'] = True
+
+ return buffer_data
+
+ @override
+ async def _apply_buffer_updates(self, session: AsyncSession, max_buffer_id: int) -> None:
+ aggregations: list[ColumnElement[Any]] = [
+ sql_func.max(self._BUFFER_TABLE.accessed_at).label('max_accessed_at'),
+ sql_func.max(self._BUFFER_TABLE.modified_at).label('max_modified_at'),
+ sql_func.sum(self._BUFFER_TABLE.delta_handled_count).label('delta_handled_count'),
+ sql_func.sum(self._BUFFER_TABLE.delta_pending_count).label('delta_pending_count'),
+ sql_func.sum(self._BUFFER_TABLE.delta_total_count).label('delta_total_count'),
+ ]
+
+ if not self._had_multiple_clients:
+ aggregations.append(
+ sql_func.count(sql_func.distinct(self._BUFFER_TABLE.client_id)).label('unique_clients_count')
+ )
+
+ if self._storage_client.get_dialect_name() == 'postgresql':
+ aggregations.append(sql_func.bool_or(self._BUFFER_TABLE.need_recalc).label('need_recalc'))
+ else:
+ aggregations.append(sql_func.max(self._BUFFER_TABLE.need_recalc).label('need_recalc'))
+
+ aggregation_stmt = select(*aggregations).where(
+ self._BUFFER_TABLE.storage_id == self._id, self._BUFFER_TABLE.id <= max_buffer_id
+ )
+
+ result = await session.execute(aggregation_stmt)
+ row = result.first()
+
+ if not row:
+ return
+
+ await self._update_metadata(
+ session,
+ **_QueueMetadataUpdateParams(
+ accessed_at=row.max_accessed_at,
+ modified_at=row.max_modified_at,
+ update_had_multiple_clients=not self._had_multiple_clients and row.unique_clients_count > 1,
+ delta_handled_request_count=row.delta_handled_count,
+ delta_pending_request_count=row.delta_pending_count,
+ delta_total_request_count=row.delta_total_count,
+ recalculate=row.need_recalc,
+ ),
+ )
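One dialect detail worth spelling out: `need_recalc` is aggregated with `bool_or()` on PostgreSQL (which has no `max(boolean)` aggregate), while the SQLite fallback relies on `Boolean` columns being stored as 0/1, so `MAX()` becomes truthy as soon as any buffered row requested a recalculation. A standalone sanity check of that fallback (toy table, not crawlee's):

```python
from sqlalchemy import Boolean, Integer, create_engine, func, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Flag(Base):
    __tablename__ = 'flags'

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    need_recalc: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)


engine = create_engine('sqlite:///:memory:')
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Flag(need_recalc=False), Flag(need_recalc=True), Flag(need_recalc=False)])
    session.commit()

    # MAX() over the 0/1-backed Boolean column reports whether any row set the flag.
    print(session.execute(select(func.max(Flag.need_recalc))).scalar())  # True
```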
diff --git a/src/crawlee/storage_clients/_sql/_storage_client.py b/src/crawlee/storage_clients/_sql/_storage_client.py
index d324a17a86..54209d4302 100644
--- a/src/crawlee/storage_clients/_sql/_storage_client.py
+++ b/src/crawlee/storage_clients/_sql/_storage_client.py
@@ -2,7 +2,7 @@
import sys
import warnings
-from datetime import timedelta
+from logging import getLogger
from pathlib import Path
from typing import TYPE_CHECKING
@@ -26,6 +26,9 @@
from sqlalchemy.ext.asyncio import AsyncSession
+logger = getLogger(__name__)
+
+
@docs_group('Storage clients')
class SqlStorageClient(StorageClient):
"""SQL implementation of the storage client.
@@ -68,9 +71,6 @@ def __init__(
self._initialized = False
self.session_maker: None | async_sessionmaker[AsyncSession] = None
- # Minimum interval to reduce database load from frequent concurrent metadata updates
- self._accessed_modified_update_interval = timedelta(seconds=1)
-
# Flag needed to apply optimizations only for default database
self._default_flag = self._engine is None and self._connection_string is None
self._dialect_name: str | None = None
@@ -106,10 +106,6 @@ def get_dialect_name(self) -> str | None:
"""Get the database dialect name."""
return self._dialect_name
- def get_accessed_modified_update_interval(self) -> timedelta:
- """Get the interval for accessed and modified updates."""
- return self._accessed_modified_update_interval
-
async def initialize(self, configuration: Configuration) -> None:
"""Initialize the database schema.
@@ -140,9 +136,7 @@ async def initialize(self, configuration: Configuration) -> None:
await conn.execute(text('PRAGMA mmap_size=268435456')) # 256MB memory mapping
await conn.execute(text('PRAGMA foreign_keys=ON')) # Enforce constraints
await conn.execute(text('PRAGMA busy_timeout=30000')) # 30s busy timeout
-
await conn.run_sync(Base.metadata.create_all, checkfirst=True)
-
from crawlee import __version__ # Noqa: PLC0415
db_version = (await conn.execute(select(VersionDb))).scalar_one_or_none()
@@ -158,7 +152,6 @@ async def initialize(self, configuration: Configuration) -> None:
)
elif not db_version:
await conn.execute(insert(VersionDb).values(version=__version__))
-
except (IntegrityError, OperationalError):
await conn.rollback()
@@ -280,10 +273,9 @@ def _get_or_create_engine(self, configuration: Configuration) -> AsyncEngine:
self._engine = create_async_engine(
connection_string,
future=True,
- pool_size=5,
- max_overflow=10,
- pool_timeout=30,
- pool_recycle=600,
+ pool_size=10,
+ max_overflow=50,
+ pool_timeout=60,
pool_pre_ping=True,
echo=False,
connect_args={'timeout': 30},
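Since these pooling defaults apply only to the engine the client builds itself, a caller that needs different limits can presumably supply a pre-configured engine instead; a hedged sketch, assuming the constructor accepts an `engine` keyword (suggested by `self._engine` above):

```python
from sqlalchemy.ext.asyncio import create_async_engine

from crawlee.storage_clients import SqlStorageClient

# Assumed keyword argument: hand over a custom AsyncEngine so pool sizing stays
# under the caller's control rather than using the defaults changed in this hunk.
engine = create_async_engine(
    'postgresql+asyncpg://user:password@localhost:5432/crawlee',
    pool_size=5,
    max_overflow=10,
    pool_pre_ping=True,
)
storage_client = SqlStorageClient(engine=engine)
```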
diff --git a/tests/unit/storage_clients/_sql/test_sql_dataset_client.py b/tests/unit/storage_clients/_sql/test_sql_dataset_client.py
index 5ad4448d4c..f802b35c62 100644
--- a/tests/unit/storage_clients/_sql/test_sql_dataset_client.py
+++ b/tests/unit/storage_clients/_sql/test_sql_dataset_client.py
@@ -1,7 +1,6 @@
from __future__ import annotations
import asyncio
-from datetime import timedelta
from typing import TYPE_CHECKING
import pytest
@@ -38,11 +37,9 @@ def get_tables(sync_conn: Connection) -> list[str]:
@pytest.fixture
async def dataset_client(
configuration: Configuration,
- monkeypatch: pytest.MonkeyPatch,
) -> AsyncGenerator[SqlDatasetClient, None]:
"""A fixture for a SQL dataset client."""
async with SqlStorageClient() as storage_client:
- monkeypatch.setattr(storage_client, '_accessed_modified_update_interval', timedelta(seconds=0))
client = await storage_client.create_dataset_client(
name='test-dataset',
configuration=configuration,
diff --git a/tests/unit/storage_clients/_sql/test_sql_kvs_client.py b/tests/unit/storage_clients/_sql/test_sql_kvs_client.py
index 89ecc891c4..69a71b0aa8 100644
--- a/tests/unit/storage_clients/_sql/test_sql_kvs_client.py
+++ b/tests/unit/storage_clients/_sql/test_sql_kvs_client.py
@@ -2,7 +2,6 @@
import asyncio
import json
-from datetime import timedelta
from typing import TYPE_CHECKING
import pytest
@@ -34,16 +33,13 @@ def configuration(tmp_path: Path) -> Configuration:
@pytest.fixture
async def kvs_client(
configuration: Configuration,
- monkeypatch: pytest.MonkeyPatch,
) -> AsyncGenerator[SqlKeyValueStoreClient, None]:
"""A fixture for a SQL key-value store client."""
async with SqlStorageClient() as storage_client:
- monkeypatch.setattr(storage_client, '_accessed_modified_update_interval', timedelta(seconds=0))
client = await storage_client.create_kvs_client(
name='test-kvs',
configuration=configuration,
)
- monkeypatch.setattr(client, '_accessed_modified_update_interval', timedelta(seconds=0))
yield client
await client.drop()
diff --git a/tests/unit/storage_clients/_sql/test_sql_rq_client.py b/tests/unit/storage_clients/_sql/test_sql_rq_client.py
index c98b7a1fc0..1139d30bbf 100644
--- a/tests/unit/storage_clients/_sql/test_sql_rq_client.py
+++ b/tests/unit/storage_clients/_sql/test_sql_rq_client.py
@@ -2,7 +2,6 @@
import asyncio
import json
-from datetime import timedelta
from typing import TYPE_CHECKING
import pytest
@@ -35,16 +34,13 @@ def configuration(tmp_path: Path) -> Configuration:
@pytest.fixture
async def rq_client(
configuration: Configuration,
- monkeypatch: pytest.MonkeyPatch,
) -> AsyncGenerator[SqlRequestQueueClient, None]:
"""A fixture for a SQL request queue client."""
async with SqlStorageClient() as storage_client:
- monkeypatch.setattr(storage_client, '_accessed_modified_update_interval', timedelta(seconds=0))
client = await storage_client.create_rq_client(
name='test-request-queue',
configuration=configuration,
)
- monkeypatch.setattr(client, '_accessed_modified_update_interval', timedelta(seconds=0))
yield client
await client.drop()