 from __future__ import annotations

 import asyncio
+import logging
 import math
 from datetime import timedelta
-from typing import Any
+from typing import Any, TypedDict

 from apify_shared.utils import filter_out_none_values_recursively, ignore_docs, parse_date_fields
-from httpx import Response

 from apify_client._errors import ApifyApiError
 from apify_client._utils import catch_not_found_or_throw, pluck_data
 from apify_client.clients.base import ResourceClient, ResourceClientAsync

+logger = logging.getLogger(__name__)
+
 _RQ_MAX_REQUESTS_PER_BATCH = 25
 _MAX_PAYLOAD_SIZE_BYTES = 9 * 1024 * 1024  # 9 MB
 _SAFETY_BUFFER_PERCENT = 0.01 / 100  # 0.01%


+class BatchAddRequestsResult(TypedDict):
+    """Result of the batch add requests operation."""
+
+    processed_requests: list[dict]
+    unprocessed_requests: list[dict]
+
+
 class RequestQueueClient(ResourceClient):
     """Sub-client for manipulating a single request queue."""

@@ -252,19 +261,15 @@ def batch_add_requests(
         requests: list[dict],
         *,
         forefront: bool | None = None,
-        max_unprocessed_requests_retries: int = 3,
-        max_parallel: int = 5,
-        min_delay_between_unprocessed_requests_retries: timedelta = timedelta(milliseconds=500),
     ) -> dict:
         """Add requests to the queue.

         https://docs.apify.com/api/v2#/reference/request-queues/batch-request-operations/add-requests

         Args:
-            requests: list of the requests to add
-            forefront: Whether to add the requests to the head or the end of the queue
+            requests (list[dict]): list of the requests to add
+            forefront (bool, optional): Whether to add the requests to the head or the end of the queue
         """
-        # TODO
         request_params = self._params(clientKey=self.client_key, forefront=forefront)

         response = self.http_client.call(
@@ -552,73 +557,117 @@ async def delete_request_lock(
             params=request_params,
         )

-    async def _batch_add_requests_inner(
+    async def _batch_add_requests_worker(
         self,
-        semaphore: asyncio.Semaphore,
+        queue: asyncio.Queue,
         request_params: dict,
-        batch: list[dict],
-    ) -> Response:
-        async with semaphore:
-            return await self.http_client.call(
-                url=self._url('requests/batch'),
-                method='POST',
-                params=request_params,
-                json=batch,
-            )
+        max_unprocessed_requests_retries: int,
+        min_delay_between_unprocessed_requests_retries: timedelta,
+    ) -> BatchAddRequestsResult:
+        processed_requests = []
+        unprocessed_requests = []
+
+        # TODO: add retry logic
+
+        try:
+            while True:
+                batch = await queue.get()
+
+                try:
+                    response = await self.http_client.call(
+                        url=self._url('requests/batch'),
+                        method='POST',
+                        params=request_params,
+                        json=batch,
+                    )
+
+                    response_parsed = parse_date_fields(pluck_data(response.json()))
+
+                    if 200 <= response.status_code <= 299:
+                        processed_requests.append(response_parsed)
+                    else:
+                        unprocessed_requests.append(response_parsed)
+                finally:
+                    # Mark the batch as done even if the call failed; otherwise
+                    # `queue.join()` in `batch_add_requests` would never return.
+                    queue.task_done()
+
+        except asyncio.CancelledError:
+            logger.debug('Worker task was cancelled.')
+
+        except Exception as exc:
+            logger.warning('Worker task failed with an exception.', exc_info=exc)
+
+        return {
+            'processed_requests': processed_requests,
+            'unprocessed_requests': unprocessed_requests,
+        }

     async def batch_add_requests(
         self: RequestQueueClientAsync,
         requests: list[dict],
         *,
         forefront: bool = False,
-        max_unprocessed_requests_retries: int = 3,
         max_parallel: int = 5,
+        max_unprocessed_requests_retries: int = 3,
         min_delay_between_unprocessed_requests_retries: timedelta = timedelta(milliseconds=500),
-    ) -> list[dict]:
-        """Add requests to the queue.
+    ) -> BatchAddRequestsResult:
+        """Add requests to the request queue in batches.

         https://docs.apify.com/api/v2#/reference/request-queues/batch-request-operations/add-requests

         Args:
-            requests: List of requests to add.
-            forefront: Whether to add the requests to the head or the end of the queue.
-            max_unprocessed_requests_retries: Number of retries for unprocessed requests.
-            max_parallel: Maximum number of parallel operations.
-            min_delay_between_unprocessed_requests_retries: Minimum delay between retries for unprocessed requests.
+            requests: List of the requests to add.
+            forefront: Whether to add the requests to the head or the end of the request queue.
+            max_unprocessed_requests_retries: Number of retry API calls for unprocessed requests.
+            max_parallel: Maximum number of parallel calls to the API.
+            min_delay_between_unprocessed_requests_retries: Minimum delay between retry API calls for unprocessed requests.
+
+        Returns:
+            Result of the operation with processed and unprocessed requests.
         """
         payload_size_limit_bytes = _MAX_PAYLOAD_SIZE_BYTES - math.ceil(_MAX_PAYLOAD_SIZE_BYTES * _SAFETY_BUFFER_PERCENT)

-        tasks = set[asyncio.Task]()
-
-        responses = list[dict]()
-
         request_params = self._params(clientKey=self.client_key, forefront=forefront)
+        tasks = set[asyncio.Task]()
+        queue: asyncio.Queue[list[dict]] = asyncio.Queue()

-        semaphore = asyncio.Semaphore(max_parallel)
-
-        number_of_iterations = math.ceil(len(requests) / _RQ_MAX_REQUESTS_PER_BATCH)
+        # Get the number of request batches.
+        number_of_batches = math.ceil(len(requests) / _RQ_MAX_REQUESTS_PER_BATCH)

-        for i in range(number_of_iterations):
+        # Split requests into batches and put them into the queue.
+        for i in range(number_of_batches):
             start = i * _RQ_MAX_REQUESTS_PER_BATCH
             end = (i + 1) * _RQ_MAX_REQUESTS_PER_BATCH
             batch = requests[start:end]
+            await queue.put(batch)

+        # Start the worker tasks.
+        for i in range(max_parallel):
             task = asyncio.create_task(
-                coro=self._batch_add_requests_inner(
-                    semaphore=semaphore,
-                    request_params=request_params,
-                    batch=batch,
+                self._batch_add_requests_worker(
+                    queue,
+                    request_params,
+                    max_unprocessed_requests_retries,
+                    min_delay_between_unprocessed_requests_retries,
                 ),
-                name=f'batch_add_requests_{i}',
+                name=f'batch_add_requests_worker_{i}',
             )
-
             tasks.add(task)
-            task.add_done_callback(lambda response: responses.append(response.result().json()))
-            task.add_done_callback(lambda _: tasks.remove(task))

-        asyncio.gather(*tasks)
+        # Wait for all batches to be processed.
+        await queue.join()
+
+        # Send cancel signals to all worker tasks.
+        for task in tasks:
+            task.cancel()

-        return [parse_date_fields(pluck_data(response)) for response in responses]
+        # Wait for all worker tasks to finish.
+        results: list[BatchAddRequestsResult] = await asyncio.gather(*tasks)
+
+        # Combine the results from all worker tasks.
+        return {
+            'processed_requests': [req for result in results for req in result['processed_requests']],
+            'unprocessed_requests': [req for result in results for req in result['unprocessed_requests']],
+        }

     async def batch_delete_requests(self: RequestQueueClientAsync, requests: list[dict]) -> dict:
         """Delete given requests from the queue.
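The new worker still carries a `# TODO: add retry logic` marker: `max_unprocessed_requests_retries` and `min_delay_between_unprocessed_requests_retries` are threaded through to it but not used yet. Below is a minimal sketch of what such a retry loop could look like, written as a standalone coroutine rather than as the library's eventual implementation; the `call` parameter is a hypothetical stand-in for `self.http_client.call`, and the fixed delay between attempts is an assumption based on the parameter names.

from __future__ import annotations

import asyncio
from datetime import timedelta


async def add_batch_with_retries(
    call,  # hypothetical stand-in for self.http_client.call
    url: str,
    params: dict,
    batch: list[dict],
    max_retries: int = 3,
    min_delay: timedelta = timedelta(milliseconds=500),
) -> tuple[dict | None, dict | None]:
    """Post one batch, retrying failed attempts after a fixed delay.

    Returns a (processed, unprocessed) pair; exactly one element is set.
    """
    response_parsed = None
    for _attempt in range(1 + max_retries):
        response = await call(url=url, method='POST', params=params, json=batch)
        response_parsed = response.json()

        if 200 <= response.status_code <= 299:
            # Success: report the parsed response as processed.
            return response_parsed, None

        # Failed attempt: wait at least the minimum delay before retrying.
        await asyncio.sleep(min_delay.total_seconds())

    # All attempts exhausted: report the batch as unprocessed.
    return None, response_parsed

Retrying inside the worker would keep the queue bookkeeping untouched, since a batch is only marked done after its final attempt.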
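For context, this is roughly how the reworked async method would be consumed. A hypothetical driver script against the `ApifyClientAsync` API; the token and queue name are placeholders.

from __future__ import annotations

import asyncio

from apify_client import ApifyClientAsync


async def main() -> None:
    apify_client = ApifyClientAsync(token='MY-APIFY-TOKEN')

    # Create (or open) a request queue and get its sub-client.
    rq = await apify_client.request_queues().get_or_create(name='example-queue')
    rq_client = apify_client.request_queue(rq['id'])

    # 100 requests are split into 4 batches of 25 and posted by up to
    # 5 worker tasks running in parallel.
    requests = [
        {'url': f'https://example.com/{i}', 'uniqueKey': f'https://example.com/{i}'}
        for i in range(100)
    ]
    result = await rq_client.batch_add_requests(requests, max_parallel=5)

    # At this stage of the implementation, each entry corresponds to one
    # batch-level API response rather than one individual request.
    print(len(result['processed_requests']), 'processed')
    print(len(result['unprocessed_requests']), 'unprocessed')


if __name__ == '__main__':
    asyncio.run(main())

Replacing the semaphore and `add_done_callback` bookkeeping with an `asyncio.Queue` and long-lived workers bounds concurrency at `max_parallel`, lets `queue.join()` mark the point where every batch has been posted, and collects results from the workers' return values instead of mutating a shared list from callbacks.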