
Commit a66c249

use constrained batches from more itertools
1 parent 6f6394a commit a66c249

File tree

1 file changed: +43 -32 lines changed


src/apify_client/clients/resource_clients/request_queue.py

Lines changed: 43 additions & 32 deletions
@@ -5,15 +5,18 @@
 import math
 from dataclasses import dataclass
 from datetime import timedelta
-from multiprocessing import process
-from typing import Any, TypedDict
+from typing import TYPE_CHECKING, Any, TypedDict
 
 from apify_shared.utils import filter_out_none_values_recursively, ignore_docs, parse_date_fields
+from more_itertools import constrained_batches
 
 from apify_client._errors import ApifyApiError
 from apify_client._utils import catch_not_found_or_throw, pluck_data
 from apify_client.clients.base import ResourceClient, ResourceClientAsync
 
+if TYPE_CHECKING:
+    from collections.abc import Iterable
+
 logger = logging.getLogger(__name__)
 
 _RQ_MAX_REQUESTS_PER_BATCH = 25
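
Note for context: constrained_batches() comes from the more_itertools package and groups an iterable into tuples under two constraints, a total size budget (max_size, measured per item by get_len, which defaults to len()) and an item-count cap (max_count). A minimal sketch of its behavior, with illustrative sample data:

    from more_itertools import constrained_batches

    # Each batch holds at most 2 items whose combined len() is at most 10.
    items = ['aaaa', 'bbbb', 'cccc', 'dd', 'ee']
    print(list(constrained_batches(items, max_size=10, max_count=2)))
    # [('aaaa', 'bbbb'), ('cccc', 'dd'), ('ee',)]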
@@ -25,7 +28,7 @@ class BatchAddRequestsResult(TypedDict):
     """Result of the batch add requests operation.
 
     Args:
-        processed_requests: List of requests that were added.
+        processed_requests: List of successfully added requests.
         unprocessed_requests: List of requests that failed to be added.
     """

@@ -39,10 +42,10 @@ class AddRequestsBatch:
 
     Args:
         requests: List of requests to be added to the request queue.
-        num_of_retries: Number of retries for the batch.
+        num_of_retries: Number of times this batch has been retried.
     """
 
-    requests: list[dict]
+    requests: Iterable[dict]
     num_of_retries: int = 0
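
The field widens from list[dict] to Iterable[dict] because constrained_batches yields tuples rather than lists. A small standalone sketch (the dataclass is restated so the snippet runs on its own):

    from dataclasses import dataclass
    from collections.abc import Iterable

    from more_itertools import constrained_batches

    @dataclass
    class AddRequestsBatch:  # mirrors the dataclass above
        requests: Iterable[dict]
        num_of_retries: int = 0

    requests = [{'url': 'https://example.com/a'}, {'url': 'https://example.com/b'}]
    for group in constrained_batches(requests, max_size=25, max_count=25):
        print(AddRequestsBatch(group))  # .requests is a tuple[dict, ...]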
@@ -584,33 +587,40 @@ async def _batch_add_requests_worker(
         max_unprocessed_requests_retries: int,
         min_delay_between_unprocessed_requests_retries: timedelta,
     ) -> BatchAddRequestsResult:
+        """Worker function to process a batch of requests.
+
+        This worker will process batches from the queue, retrying requests that fail until the retry limit is reached.
+
+        Returns a result containing lists of processed and unprocessed requests handled by this worker.
+        """
         processed_requests = list[dict]()
         unprocessed_requests = list[dict]()
 
         while True:
+            # Get the next batch from the queue.
             try:
                 batch = await queue.get()
             except asyncio.CancelledError:
                 break
 
             try:
+                # Send the batch to the API.
                 response = await self.http_client.call(
                     url=self._url('requests/batch'),
                     method='POST',
                     params=request_params,
-                    json=batch.requests,
+                    json=list(batch.requests),
                 )
 
                 response_parsed = parse_date_fields(pluck_data(response.json()))
 
-                # If the request was not successful and the number of retries is less than the maximum,
-                # put the batch back into the queue and retry the request later.
-                if (not response.is_success) and batch.num_of_retries < max_unprocessed_requests_retries:
+                # Retry if the request failed and the retry limit has not been reached.
+                if not response.is_success and batch.num_of_retries < max_unprocessed_requests_retries:
                     batch.num_of_retries += 1
                     await asyncio.sleep(min_delay_between_unprocessed_requests_retries.total_seconds())
                     await queue.put(batch)
 
-                # Otherwise, extract the processed and unprocessed requests from the response.
+                # Otherwise, add the processed/unprocessed requests to their respective lists.
                 else:
                     processed_requests.extend(response_parsed.get('processedRequests', []))
                     unprocessed_requests.extend(response_parsed.get('unprocessedRequests', []))
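
The switch to json=list(batch.requests) follows from the type change above: the standard json encoder (which the HTTP client is assumed to use for its json payload) can serialize lists and tuples, but not arbitrary iterables such as generators. A quick illustration:

    import json

    requests = ({'url': 'https://example.com'},)  # constrained_batches yields tuples
    print(json.dumps(list(requests)))  # [{"url": "https://example.com"}]

    try:
        json.dumps(r for r in requests)  # a generator is not JSON serializable
    except TypeError as exc:
        print(exc)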
@@ -619,6 +629,7 @@ async def _batch_add_requests_worker(
                 logger.warning(f'Error occurred while processing a batch of requests: {exc}')
 
             finally:
+                # Mark the batch as done whether it succeeded or failed.
                 queue.task_done()
 
         return {
@@ -637,36 +648,36 @@ async def batch_add_requests(
     ) -> BatchAddRequestsResult:
         """Add requests to the request queue in batches.
 
-        https://docs.apify.com/api/v2#/reference/request-queues/batch-request-operations/add-requests
+        Requests are split into batches based on size and processed in parallel.
 
         Args:
-            requests: List of the requests to add.
-            forefront: Whether to add the requests to the head or the end of the request queue.
-            max_unprocessed_requests_retries: Number of retry API calls for unprocessed requests.
-            max_parallel: Maximum number of parallel calls to the API.
-            min_delay_between_unprocessed_requests_retries: Minimum delay between retry API calls for unprocessed requests.
+            requests: List of requests to be added to the queue.
+            forefront: Whether to add requests to the front of the queue.
+            max_parallel: Maximum number of parallel tasks for API calls.
+            max_unprocessed_requests_retries: Number of retry attempts for unprocessed requests.
+            min_delay_between_unprocessed_requests_retries: Minimum delay between retry attempts for unprocessed requests.
 
         Returns:
-            Result of the operation with processed and unprocessed requests.
+            Result containing lists of processed and unprocessed requests.
         """
-        payload_size_limit_bytes = _MAX_PAYLOAD_SIZE_BYTES - math.ceil(_MAX_PAYLOAD_SIZE_BYTES * _SAFETY_BUFFER_PERCENT)
-        # TODO: payload size limit bytes
-
-        request_params = self._params(clientKey=self.client_key, forefront=forefront)
         tasks = set[asyncio.Task]()
         queue: asyncio.Queue[AddRequestsBatch] = asyncio.Queue()
+        request_params = self._params(clientKey=self.client_key, forefront=forefront)
 
-        # Get the number of request batches.
-        number_of_batches = math.ceil(len(requests) / _RQ_MAX_REQUESTS_PER_BATCH)
+        # Compute the payload size limit to ensure it doesn't exceed the maximum allowed size.
+        payload_size_limit_bytes = _MAX_PAYLOAD_SIZE_BYTES - math.ceil(_MAX_PAYLOAD_SIZE_BYTES * _SAFETY_BUFFER_PERCENT)
+
+        # Split the requests into batches, constrained by the max payload size and max requests per batch.
+        batches = constrained_batches(
+            requests,
+            max_size=payload_size_limit_bytes,
+            max_count=_RQ_MAX_REQUESTS_PER_BATCH,
+        )
 
-        # Split requests into batches and put them into the queue.
-        for i in range(number_of_batches):
-            start = i * _RQ_MAX_REQUESTS_PER_BATCH
-            end = (i + 1) * _RQ_MAX_REQUESTS_PER_BATCH
-            batch = AddRequestsBatch(requests[start:end])
-            await queue.put(batch)
+        for batch in batches:
+            await queue.put(AddRequestsBatch(batch))
 
-        # Start the worker tasks.
+        # Start the required number of worker tasks to process the batches.
         for i in range(max_parallel):
             coro = self._batch_add_requests_worker(
                 queue,
@@ -680,13 +691,13 @@ async def batch_add_requests(
         # Wait for all batches to be processed.
         await queue.join()
 
-        # Send cancel signals to all worker tasks and wait for them to finish.
+        # Send cancellation signals to all worker tasks and wait for them to finish.
         for task in tasks:
             task.cancel()
 
         results: list[BatchAddRequestsResult] = await asyncio.gather(*tasks)
 
-        # Combine the results from all worker tasks and return them.
+        # Combine the results from all workers and return them.
         processed_requests = []
         unprocessed_requests = []
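
For readers unfamiliar with this shutdown choreography: queue.join() returns once every queued item has been matched by a task_done() call, after which the workers, now idle and blocked in queue.get(), are cancelled and exit cleanly. A minimal self-contained sketch with hypothetical names:

    import asyncio

    async def worker(queue: asyncio.Queue) -> list[int]:
        results = []
        while True:
            try:
                item = await queue.get()
            except asyncio.CancelledError:
                break  # cancelled while idle: shut down cleanly
            try:
                results.append(item * 2)  # stand-in for the API call
            finally:
                queue.task_done()  # always mark the item as done
        return results

    async def main() -> None:
        queue: asyncio.Queue = asyncio.Queue()
        for i in range(10):
            await queue.put(i)

        tasks = [asyncio.create_task(worker(queue)) for _ in range(3)]
        await queue.join()  # returns when all items are marked done
        for task in tasks:
            task.cancel()
        print(await asyncio.gather(*tasks))

    asyncio.run(main())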
