diff --git a/pyproject.toml b/pyproject.toml
index 2de869c8..38cf25e6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,6 +45,7 @@ keywords = [
 python = "^3.9"
 apify-shared = ">=1.1.2"
 httpx = ">=0.25.0"
+more_itertools = ">=10.0.0"
 
 [tool.poetry.group.dev.dependencies]
 build = "~1.2.0"
diff --git a/scripts/check_async_docstrings.py b/scripts/check_async_docstrings.py
index 0b5ee6d0..cbeed8c1 100755
--- a/scripts/check_async_docstrings.py
+++ b/scripts/check_async_docstrings.py
@@ -36,7 +36,7 @@
             continue
 
         # If the sync method has a docstring, check if it matches the async docstring
-        if isinstance(sync_method.value[0].value, str):
+        if sync_method and isinstance(sync_method.value[0].value, str):
            sync_docstring = sync_method.value[0].value
            async_docstring = async_method.value[0].value
            expected_docstring = sync_to_async_docstring(sync_docstring)
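The new `more_itertools` dependency is pulled in for `constrained_batches`, which the request queue client below uses to split requests into size- and count-limited batches. A minimal sketch of its behavior; the items and limits here are illustrative, and by default it measures item size with `len()`:

```python
from more_itertools import constrained_batches

items = ['aaaa', 'bb', 'c', 'ddddd', 'ee']

# Greedily fill batches whose total len() is at most 6, with at most 2 items each.
for batch in constrained_batches(items, max_size=6, max_count=2):
    print(batch)

# Output:
# ('aaaa', 'bb')
# ('c', 'ddddd')
# ('ee',)
```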
diff --git a/src/apify_client/clients/resource_clients/request_queue.py b/src/apify_client/clients/resource_clients/request_queue.py
index 8405d4b2..d17e725a 100644
--- a/src/apify_client/clients/resource_clients/request_queue.py
+++ b/src/apify_client/clients/resource_clients/request_queue.py
@@ -1,13 +1,55 @@
 from __future__ import annotations
 
-from typing import Any
+import asyncio
+import logging
+import math
+from dataclasses import dataclass
+from datetime import timedelta
+from queue import Queue
+from time import sleep
+from typing import TYPE_CHECKING, Any, TypedDict
 
 from apify_shared.utils import filter_out_none_values_recursively, ignore_docs, parse_date_fields
+from more_itertools import constrained_batches
 
 from apify_client._errors import ApifyApiError
 from apify_client._utils import catch_not_found_or_throw, pluck_data
 from apify_client.clients.base import ResourceClient, ResourceClientAsync
 
+if TYPE_CHECKING:
+    from collections.abc import Iterable
+
+logger = logging.getLogger(__name__)
+
+_RQ_MAX_REQUESTS_PER_BATCH = 25
+_MAX_PAYLOAD_SIZE_BYTES = 9 * 1024 * 1024  # 9 MB
+_SAFETY_BUFFER_PERCENT = 0.01 / 100  # 0.01%
+
+
+class BatchAddRequestsResult(TypedDict):
+    """Result of the batch add requests operation.
+
+    Attributes:
+        processedRequests: List of successfully added requests.
+        unprocessedRequests: List of requests that failed to be added.
+    """
+
+    processedRequests: list[dict]
+    unprocessedRequests: list[dict]
+
+
+@dataclass
+class AddRequestsBatch:
+    """Batch of requests to be added to the request queue.
+
+    Attributes:
+        requests: List of requests to be added to the request queue.
+        num_of_retries: Number of times this batch has been retried.
+    """
+
+    requests: Iterable[dict]
+    num_of_retries: int = 0
+
 
 class RequestQueueClient(ResourceClient):
     """Sub-client for manipulating a single request queue."""
@@ -240,28 +282,84 @@ def delete_request_lock(self: RequestQueueClient, request_id: str, *, forefront:
         )
 
     def batch_add_requests(
-        self: RequestQueueClient,
+        self,
         requests: list[dict],
         *,
-        forefront: bool | None = None,
-    ) -> dict:
-        """Add requests to the queue.
+        forefront: bool = False,
+        max_parallel: int = 1,
+        max_unprocessed_requests_retries: int = 3,
+        min_delay_between_unprocessed_requests_retries: timedelta = timedelta(milliseconds=500),
+    ) -> BatchAddRequestsResult:
+        """Add requests to the request queue in batches.
+
+        Requests are split into batches based on size and sent to the API sequentially.
 
         https://docs.apify.com/api/v2#/reference/request-queues/batch-request-operations/add-requests
 
         Args:
-            requests (list[dict]): list of the requests to add
-            forefront (bool, optional): Whether to add the requests to the head or the end of the queue
+            requests: List of requests to be added to the queue.
+            forefront: Whether to add the requests to the head of the queue.
+            max_parallel: Maximum number of parallel tasks used for the API calls. Parallel execution
+                is only supported by the async client; for the sync client, this value must be 1.
+            max_unprocessed_requests_retries: Number of retry attempts for unprocessed requests.
+            min_delay_between_unprocessed_requests_retries: Minimum delay between retry attempts
+                for unprocessed requests.
+
+        Returns:
+            Result containing lists of processed and unprocessed requests.
         """
+        if max_parallel != 1:
+            raise NotImplementedError('max_parallel is only supported in the async client')
+
         request_params = self._params(clientKey=self.client_key, forefront=forefront)
-        response = self.http_client.call(
-            url=self._url('requests/batch'),
-            method='POST',
-            params=request_params,
-            json=requests,
+
+        # Apply a small safety buffer so the maximum payload size is never exceeded.
+        payload_size_limit_bytes = _MAX_PAYLOAD_SIZE_BYTES - math.ceil(_MAX_PAYLOAD_SIZE_BYTES * _SAFETY_BUFFER_PERCENT)
+
+        # Split the requests into batches, constrained by the max payload size and max requests per batch.
+        batches = constrained_batches(
+            requests,
+            max_size=payload_size_limit_bytes,
+            max_count=_RQ_MAX_REQUESTS_PER_BATCH,
         )
-        return parse_date_fields(pluck_data(response.json()))
+
+        # Put the batches into the queue for processing.
+        queue = Queue[AddRequestsBatch]()
+
+        for batch in batches:
+            queue.put(AddRequestsBatch(batch))
+
+        processed_requests = list[dict]()
+        unprocessed_requests = list[dict]()
+
+        # Process all batches in the queue sequentially.
+        while not queue.empty():
+            batch = queue.get()
+
+            # Send the batch to the API.
+            response = self.http_client.call(
+                url=self._url('requests/batch'),
+                method='POST',
+                params=request_params,
+                json=list(batch.requests),
+            )
+
+            # Retry if the request failed and the retry limit has not been reached.
+            if not response.is_success and batch.num_of_retries < max_unprocessed_requests_retries:
+                batch.num_of_retries += 1
+                sleep(min_delay_between_unprocessed_requests_retries.total_seconds())
+                queue.put(batch)
+
+            # Otherwise, add the processed/unprocessed requests to their respective lists.
+            else:
+                response_parsed = parse_date_fields(pluck_data(response.json()))
+                processed_requests.extend(response_parsed.get('processedRequests', []))
+                unprocessed_requests.extend(response_parsed.get('unprocessedRequests', []))
+
+        return {
+            'processedRequests': processed_requests,
+            'unprocessedRequests': unprocessed_requests,
+        }
 
     def batch_delete_requests(self: RequestQueueClient, requests: list[dict]) -> dict:
         """Delete given requests from the queue.
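For context, a usage sketch of the new sync method; the token, queue ID, and URLs are placeholders, and the request dicts follow the usual `url`/`uniqueKey` shape:

```python
from apify_client import ApifyClient

client = ApifyClient(token='<your-api-token>')  # placeholder token
queue_client = client.request_queue('<queue-id>')  # placeholder queue ID

requests = [
    {'url': 'https://example.com/1', 'uniqueKey': 'https://example.com/1'},
    {'url': 'https://example.com/2', 'uniqueKey': 'https://example.com/2'},
]

result = queue_client.batch_add_requests(requests)
print(len(result['processedRequests']), 'processed,', len(result['unprocessedRequests']), 'unprocessed')
```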
@@ -540,29 +638,139 @@ async def delete_request_lock(
             params=request_params,
         )
 
+    async def _batch_add_requests_worker(
+        self,
+        queue: asyncio.Queue[AddRequestsBatch],
+        request_params: dict,
+        max_unprocessed_requests_retries: int,
+        min_delay_between_unprocessed_requests_retries: timedelta,
+    ) -> BatchAddRequestsResult:
+        """Worker task that processes batches of requests from the queue.
+
+        The worker keeps taking batches from the queue, retrying batches that fail until the retry limit is reached.
+
+        Returns:
+            Result containing lists of processed and unprocessed requests handled by this worker.
+        """
+        processed_requests = list[dict]()
+        unprocessed_requests = list[dict]()
+
+        while True:
+            # Get the next batch from the queue; exit when the worker task is cancelled.
+            try:
+                batch = await queue.get()
+            except asyncio.CancelledError:
+                break
+
+            try:
+                # Send the batch to the API.
+                response = await self.http_client.call(
+                    url=self._url('requests/batch'),
+                    method='POST',
+                    params=request_params,
+                    json=list(batch.requests),
+                )
+
+                # Retry if the request failed and the retry limit has not been reached.
+                if not response.is_success and batch.num_of_retries < max_unprocessed_requests_retries:
+                    batch.num_of_retries += 1
+                    await asyncio.sleep(min_delay_between_unprocessed_requests_retries.total_seconds())
+                    await queue.put(batch)
+
+                # Otherwise, parse the response and add the processed/unprocessed requests
+                # to their respective lists.
+                else:
+                    response_parsed = parse_date_fields(pluck_data(response.json()))
+                    processed_requests.extend(response_parsed.get('processedRequests', []))
+                    unprocessed_requests.extend(response_parsed.get('unprocessedRequests', []))
+
+            except Exception as exc:
+                logger.warning(f'Error occurred while processing a batch of requests: {exc}')
+
+            finally:
+                # Mark the batch as done whether it succeeded or failed.
+                queue.task_done()
+
+        return {
+            'processedRequests': processed_requests,
+            'unprocessedRequests': unprocessed_requests,
+        }
+
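The worker above is a standard `asyncio.Queue` consumer: every dequeued item is acknowledged with `task_done()`, the producer waits on `queue.join()`, and drained workers are shut down by cancelling them at the `queue.get()` await. A minimal, self-contained sketch of that orchestration; all names here are illustrative, not part of the client:

```python
import asyncio


async def worker(queue: asyncio.Queue) -> int:
    """Consume items until cancelled; return how many were handled."""
    handled = 0
    while True:
        try:
            item = await queue.get()  # cancellation arrives here once the queue is drained
        except asyncio.CancelledError:
            break
        try:
            handled += 1  # stand-in for the real API call using `item`
        finally:
            queue.task_done()  # unblocks queue.join() even if processing raised

    return handled


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    for item in range(10):
        await queue.put(item)

    workers = [asyncio.create_task(worker(queue)) for _ in range(3)]

    await queue.join()  # returns once task_done() was called for every item
    for task in workers:
        task.cancel()  # workers catch this and return their partial results

    results = await asyncio.gather(*workers)
    print(sum(results))  # 10


asyncio.run(main())
```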
     async def batch_add_requests(
-        self: RequestQueueClientAsync,
+        self,
         requests: list[dict],
         *,
-        forefront: bool | None = None,
-    ) -> dict:
-        """Add requests to the queue.
+        forefront: bool = False,
+        max_parallel: int = 5,
+        max_unprocessed_requests_retries: int = 3,
+        min_delay_between_unprocessed_requests_retries: timedelta = timedelta(milliseconds=500),
+    ) -> BatchAddRequestsResult:
+        """Add requests to the request queue in batches.
+
+        Requests are split into batches based on size and processed in parallel.
 
         https://docs.apify.com/api/v2#/reference/request-queues/batch-request-operations/add-requests
 
         Args:
-            requests (list[dict]): list of the requests to add
-            forefront (bool, optional): Whether to add the requests to the head or the end of the queue
+            requests: List of requests to be added to the queue.
+            forefront: Whether to add the requests to the head of the queue.
+            max_parallel: Maximum number of parallel tasks used for the API calls. Parallel execution
+                is only supported by the async client; for the sync client, this value must be 1.
+            max_unprocessed_requests_retries: Number of retry attempts for unprocessed requests.
+            min_delay_between_unprocessed_requests_retries: Minimum delay between retry attempts
+                for unprocessed requests.
+
+        Returns:
+            Result containing lists of processed and unprocessed requests.
         """
+        tasks = set[asyncio.Task]()
+        queue: asyncio.Queue[AddRequestsBatch] = asyncio.Queue()
         request_params = self._params(clientKey=self.client_key, forefront=forefront)
 
-        response = await self.http_client.call(
-            url=self._url('requests/batch'),
-            method='POST',
-            params=request_params,
-            json=requests,
+        # Apply a small safety buffer so the maximum payload size is never exceeded.
+        payload_size_limit_bytes = _MAX_PAYLOAD_SIZE_BYTES - math.ceil(_MAX_PAYLOAD_SIZE_BYTES * _SAFETY_BUFFER_PERCENT)
+
+        # Split the requests into batches, constrained by the max payload size and max requests per batch.
+        batches = constrained_batches(
+            requests,
+            max_size=payload_size_limit_bytes,
+            max_count=_RQ_MAX_REQUESTS_PER_BATCH,
         )
-        return parse_date_fields(pluck_data(response.json()))
+
+        for batch in batches:
+            await queue.put(AddRequestsBatch(batch))
+
+        # Start the requested number of worker tasks to process the batches.
+        for i in range(max_parallel):
+            coro = self._batch_add_requests_worker(
+                queue,
+                request_params,
+                max_unprocessed_requests_retries,
+                min_delay_between_unprocessed_requests_retries,
+            )
+            task = asyncio.create_task(coro, name=f'batch_add_requests_worker_{i}')
+            tasks.add(task)
+
+        # Wait for all batches to be processed.
+        await queue.join()
+
+        # Send cancellation signals to all worker tasks and wait for them to finish.
+        for task in tasks:
+            task.cancel()
+
+        results: list[BatchAddRequestsResult] = await asyncio.gather(*tasks)
+
+        # Combine the results from all workers and return them.
+        processed_requests = []
+        unprocessed_requests = []
+
+        for result in results:
+            processed_requests.extend(result['processedRequests'])
+            unprocessed_requests.extend(result['unprocessedRequests'])
+
+        return {
+            'processedRequests': processed_requests,
+            'unprocessedRequests': unprocessed_requests,
+        }
 
     async def batch_delete_requests(self: RequestQueueClientAsync, requests: list[dict]) -> dict:
         """Delete given requests from the queue.
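Putting it together, a corresponding sketch for the async client, where `max_parallel` actually fans the batches out across worker tasks; the token, queue ID, and URLs are again placeholders:

```python
import asyncio

from apify_client import ApifyClientAsync


async def main() -> None:
    client = ApifyClientAsync(token='<your-api-token>')  # placeholder token
    queue_client = client.request_queue('<queue-id>')  # placeholder queue ID

    requests = [
        {'url': f'https://example.com/{i}', 'uniqueKey': f'https://example.com/{i}'}
        for i in range(100)
    ]

    # Up to 5 worker tasks send the batches concurrently.
    result = await queue_client.batch_add_requests(requests, max_parallel=5)
    print(len(result['processedRequests']), 'processed')


asyncio.run(main())
```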