5
5
import math
6
6
from dataclasses import dataclass
7
7
from datetime import timedelta
8
- from multiprocessing import process
9
- from typing import Any , TypedDict
8
+ from typing import TYPE_CHECKING , Any , TypedDict
10
9
11
10
from apify_shared .utils import filter_out_none_values_recursively , ignore_docs , parse_date_fields
11
+ from more_itertools import constrained_batches
12
12
13
13
from apify_client ._errors import ApifyApiError
14
14
from apify_client ._utils import catch_not_found_or_throw , pluck_data
15
15
from apify_client .clients .base import ResourceClient , ResourceClientAsync
16
16
17
+ if TYPE_CHECKING :
18
+ from collections .abc import Iterable
19
+
17
20
logger = logging .getLogger (__name__ )
18
21
19
22
_RQ_MAX_REQUESTS_PER_BATCH = 25
@@ -25,7 +28,7 @@ class BatchAddRequestsResult(TypedDict):
25
28
"""Result of the batch add requests operation.
26
29
27
30
Args:
28
- processed_requests: List of requests that were added.
31
+ processed_requests: List of successfully added requests .
29
32
unprocessed_requests: List of requests that failed to be added.
30
33
"""
31
34
@@ -39,10 +42,10 @@ class AddRequestsBatch:
39
42
40
43
Args:
41
44
requests: List of requests to be added to the request queue.
42
- num_of_retries: Number of retries for the batch.
45
+ num_of_retries: Number of times this batch has been retried .
43
46
"""
44
47
45
- requests : list [dict ]
48
+ requests : Iterable [dict ]
46
49
num_of_retries : int = 0
47
50
48
51
@@ -584,33 +587,40 @@ async def _batch_add_requests_worker(
584
587
max_unprocessed_requests_retries : int ,
585
588
min_delay_between_unprocessed_requests_retries : timedelta ,
586
589
) -> BatchAddRequestsResult :
590
+ """Worker function to process a batch of requests.
591
+
592
+ This worker will process batches from the queue, retrying requests that fail until the retry limit is reached.
593
+
594
+ Returns result containing lists of processed and unprocessed requests by the worker.
595
+ """
587
596
processed_requests = list [dict ]()
588
597
unprocessed_requests = list [dict ]()
589
598
590
599
while True :
600
+ # Get the next batch from the queue.
591
601
try :
592
602
batch = await queue .get ()
593
603
except asyncio .CancelledError :
594
604
break
595
605
596
606
try :
607
+ # Send the batch to the API.
597
608
response = await self .http_client .call (
598
609
url = self ._url ('requests/batch' ),
599
610
method = 'POST' ,
600
611
params = request_params ,
601
- json = batch .requests ,
612
+ json = list ( batch .requests ) ,
602
613
)
603
614
604
615
response_parsed = parse_date_fields (pluck_data (response .json ()))
605
616
606
- # If the request was not successful and the number of retries is less than the maximum,
607
- # put the batch back into the queue and retry the request later.
608
- if (not response .is_success ) and batch .num_of_retries < max_unprocessed_requests_retries :
617
+ # Retry if the request failed and the retry limit has not been reached.
618
+ if not response .is_success and batch .num_of_retries < max_unprocessed_requests_retries :
609
619
batch .num_of_retries += 1
610
620
await asyncio .sleep (min_delay_between_unprocessed_requests_retries .total_seconds ())
611
621
await queue .put (batch )
612
622
613
- # Otherwise, extract the processed and unprocessed requests from the response .
623
+ # Otherwise, add the processed/ unprocessed requests to their respective lists .
614
624
else :
615
625
processed_requests .extend (response_parsed .get ('processedRequests' , []))
616
626
unprocessed_requests .extend (response_parsed .get ('unprocessedRequests' , []))
@@ -619,6 +629,7 @@ async def _batch_add_requests_worker(
619
629
logger .warning (f'Error occurred while processing a batch of requests: { exc } ' )
620
630
621
631
finally :
632
+ # Mark the batch as done whether it succeeded or failed.
622
633
queue .task_done ()
623
634
624
635
return {
@@ -637,36 +648,36 @@ async def batch_add_requests(
637
648
) -> BatchAddRequestsResult :
638
649
"""Add requests to the request queue in batches.
639
650
640
- https://docs.apify.com/api/v2#/reference/request-queues/batch-request-operations/add-requests
651
+ Requests are split into batches based on size and processed in parallel.
641
652
642
653
Args:
643
- requests: List of the requests to add .
644
- forefront: Whether to add the requests to the head or the end of the request queue.
645
- max_unprocessed_requests_retries: Number of retry API calls for unprocessed requests .
646
- max_parallel: Maximum number of parallel calls to the API .
647
- min_delay_between_unprocessed_requests_retries: Minimum delay between retry API calls for unprocessed requests.
654
+ requests: List of requests to be added to the queue .
655
+ forefront: Whether to add requests to the front of the queue.
656
+ max_parallel: Maximum number of parallel tasks for API calls .
657
+ max_unprocessed_requests_retries: Number of retry attempts for unprocessed requests .
658
+ min_delay_between_unprocessed_requests_retries: Minimum delay between retry attempts for unprocessed requests.
648
659
649
660
Returns:
650
- Result of the operation with processed and unprocessed requests.
661
+ Result containing lists of processed and unprocessed requests.
651
662
"""
652
- payload_size_limit_bytes = _MAX_PAYLOAD_SIZE_BYTES - math .ceil (_MAX_PAYLOAD_SIZE_BYTES * _SAFETY_BUFFER_PERCENT )
653
- # TODO: payload size limit bytes
654
-
655
- request_params = self ._params (clientKey = self .client_key , forefront = forefront )
656
663
tasks = set [asyncio .Task ]()
657
664
queue : asyncio .Queue [AddRequestsBatch ] = asyncio .Queue ()
665
+ request_params = self ._params (clientKey = self .client_key , forefront = forefront )
658
666
659
- # Get the number of request batches.
660
- number_of_batches = math .ceil (len (requests ) / _RQ_MAX_REQUESTS_PER_BATCH )
667
+ # Compute the payload size limit to ensure it doesn't exceed the maximum allowed size.
668
+ payload_size_limit_bytes = _MAX_PAYLOAD_SIZE_BYTES - math .ceil (_MAX_PAYLOAD_SIZE_BYTES * _SAFETY_BUFFER_PERCENT )
669
+
670
+ # Split the requests into batches, constrained by the max payload size and max requests per batch.
671
+ batches = constrained_batches (
672
+ requests ,
673
+ max_size = payload_size_limit_bytes ,
674
+ max_count = _RQ_MAX_REQUESTS_PER_BATCH ,
675
+ )
661
676
662
- # Split requests into batches and put them into the queue.
663
- for i in range (number_of_batches ):
664
- start = i * _RQ_MAX_REQUESTS_PER_BATCH
665
- end = (i + 1 ) * _RQ_MAX_REQUESTS_PER_BATCH
666
- batch = AddRequestsBatch (requests [start :end ])
667
- await queue .put (batch )
677
+ for batch in batches :
678
+ await queue .put (AddRequestsBatch (batch ))
668
679
669
- # Start the worker tasks.
680
+ # Start a required number of worker tasks to process the batches .
670
681
for i in range (max_parallel ):
671
682
coro = self ._batch_add_requests_worker (
672
683
queue ,
@@ -680,13 +691,13 @@ async def batch_add_requests(
680
691
# Wait for all batches to be processed.
681
692
await queue .join ()
682
693
683
- # Send cancel signals to all worker tasks and wait for them to finish.
694
+ # Send cancellation signals to all worker tasks and wait for them to finish.
684
695
for task in tasks :
685
696
task .cancel ()
686
697
687
698
results : list [BatchAddRequestsResult ] = await asyncio .gather (* tasks )
688
699
689
- # Combine the results from all worker tasks and return them.
700
+ # Combine the results from all workers and return them.
690
701
processed_requests = []
691
702
unprocessed_requests = []
692
703
0 commit comments