Skip to content

Commit f49ac31

Browse files
committed
first version of implementation with workers
1 parent 5dda869 commit f49ac31

File tree

1 file changed

+94
-45
lines changed

1 file changed

+94
-45
lines changed

src/apify_client/clients/resource_clients/request_queue.py

Lines changed: 94 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,31 @@
11
from __future__ import annotations
22

33
import asyncio
4+
import logging
45
import math
56
from datetime import timedelta
6-
from typing import Any
7+
from typing import Any, TypedDict
78

89
from apify_shared.utils import filter_out_none_values_recursively, ignore_docs, parse_date_fields
9-
from httpx import Response
1010

1111
from apify_client._errors import ApifyApiError
1212
from apify_client._utils import catch_not_found_or_throw, pluck_data
1313
from apify_client.clients.base import ResourceClient, ResourceClientAsync
1414

15+
logger = logging.getLogger(__name__)
16+
1517
_RQ_MAX_REQUESTS_PER_BATCH = 25
1618
_MAX_PAYLOAD_SIZE_BYTES = 9 * 1024 * 1024 # 9 MB
1719
_SAFETY_BUFFER_PERCENT = 0.01 / 100 # 0.01%
1820

1921

22+
class BatchAddRequestsResult(TypedDict):
    """Result of the batch add requests operation.

    Attributes:
        processed_requests: Parsed API responses for batches that were submitted successfully (2xx status).
        unprocessed_requests: Parsed API responses for batches the API did not accept (non-2xx status).
    """

    processed_requests: list[dict]
    unprocessed_requests: list[dict]
27+
28+
2029
class RequestQueueClient(ResourceClient):
2130
"""Sub-client for manipulating a single request queue."""
2231

@@ -252,19 +261,15 @@ def batch_add_requests(
252261
requests: list[dict],
253262
*,
254263
forefront: bool | None = None,
255-
max_unprocessed_requests_retries: int = 3,
256-
max_parallel: int = 5,
257-
min_delay_between_unprocessed_requests_retries: timedelta = timedelta(milliseconds=500),
258264
) -> dict:
259265
"""Add requests to the queue.
260266
261267
https://docs.apify.com/api/v2#/reference/request-queues/batch-request-operations/add-requests
262268
263269
Args:
264-
requests: list of the requests to add
265-
forefront: Whether to add the requests to the head or the end of the queue
270+
requests (list[dict]): list of the requests to add
271+
forefront (bool, optional): Whether to add the requests to the head or the end of the queue
266272
"""
267-
# TODO
268273
request_params = self._params(clientKey=self.client_key, forefront=forefront)
269274

270275
response = self.http_client.call(
@@ -552,73 +557,117 @@ async def delete_request_lock(
552557
params=request_params,
553558
)
554559

555-
async def _batch_add_requests_worker(
    self,
    queue: asyncio.Queue,
    request_params: dict,
    max_unprocessed_requests_retries: int,
    min_delay_between_unprocessed_requests_retries: timedelta,
) -> BatchAddRequestsResult:
    """Consume request batches from the queue and submit them to the API until cancelled.

    Each consumed batch is acknowledged with `queue.task_done()` exactly once, so the
    caller's `queue.join()` can complete. The worker runs until the caller cancels it.

    Args:
        queue: Shared queue holding the request batches to submit.
        request_params: Prepared query parameters for the batch API calls.
        max_unprocessed_requests_retries: Number of retry API calls for unprocessed requests (not implemented yet).
        min_delay_between_unprocessed_requests_retries: Minimum delay between retry API calls (not implemented yet).

    Returns:
        Parsed responses collected by this worker, split into processed and unprocessed requests.
    """
    processed_requests = []
    unprocessed_requests = []

    # TODO: add retry logic using max_unprocessed_requests_retries
    # and min_delay_between_unprocessed_requests_retries.

    try:
        while True:
            batch = await queue.get()

            try:
                response = await self.http_client.call(
                    url=self._url('requests/batch'),
                    method='POST',
                    params=request_params,
                    json=batch,
                )

                response_parsed = parse_date_fields(pluck_data(response.json()))

                if 200 <= response.status_code <= 299:
                    processed_requests.append(response_parsed)
                else:
                    unprocessed_requests.append(response_parsed)

            except Exception as exc:
                # Keep the worker alive so the remaining batches are still drained;
                # a dead worker would leave batches in the queue and `queue.join()`
                # in the caller would never return.
                logger.warning('Worker failed to process a batch of requests.', exc_info=exc)

            finally:
                # Acknowledge once per `queue.get()`. Calling this only once per worker
                # lifetime (as a function-level `finally` would) deadlocks `queue.join()`
                # whenever there are more batches than workers.
                queue.task_done()

    except asyncio.CancelledError:
        # Normal shutdown path: the caller cancels workers after `queue.join()` returns.
        logger.debug('Worker task was cancelled.')

    return {
        'processed_requests': processed_requests,
        'unprocessed_requests': unprocessed_requests,
    }
568603

569604
async def batch_add_requests(
    self: RequestQueueClientAsync,
    requests: list[dict],
    *,
    forefront: bool = False,
    max_parallel: int = 5,
    max_unprocessed_requests_retries: int = 3,
    min_delay_between_unprocessed_requests_retries: timedelta = timedelta(milliseconds=500),
) -> BatchAddRequestsResult:
    """Add requests to the request queue in batches.

    https://docs.apify.com/api/v2#/reference/request-queues/batch-request-operations/add-requests

    Args:
        requests: List of the requests to add.
        forefront: Whether to add the requests to the head or the end of the request queue.
        max_parallel: Maximum number of parallel calls to the API.
        max_unprocessed_requests_retries: Number of retry API calls for unprocessed requests.
        min_delay_between_unprocessed_requests_retries: Minimum delay between retry API calls for unprocessed requests.

    Returns:
        Result of the operation with processed and unprocessed requests.
    """
    # TODO: enforce the payload size limit (_MAX_PAYLOAD_SIZE_BYTES with
    # _SAFETY_BUFFER_PERCENT) when splitting the requests into batches;
    # batches are currently split only by request count.
    request_params = self._params(clientKey=self.client_key, forefront=forefront)
    batch_queue: asyncio.Queue[list[dict]] = asyncio.Queue()

    # Split the requests into batches of at most _RQ_MAX_REQUESTS_PER_BATCH
    # items and enqueue them for the workers.
    for offset in range(0, len(requests), _RQ_MAX_REQUESTS_PER_BATCH):
        await batch_queue.put(requests[offset : offset + _RQ_MAX_REQUESTS_PER_BATCH])

    # Start the pool of worker tasks draining the queue in parallel.
    workers = {
        asyncio.create_task(
            self._batch_add_requests_worker(
                batch_queue,
                request_params,
                max_unprocessed_requests_retries,
                min_delay_between_unprocessed_requests_retries,
            ),
            name=f'batch_add_requests_worker_{worker_index}',
        )
        for worker_index in range(max_parallel)
    }

    # Wait until every enqueued batch has been acknowledged by a worker.
    await batch_queue.join()

    # The queue is drained; cancel the workers, which are now idle in `queue.get()`.
    for worker in workers:
        worker.cancel()

    # Collect the per-worker results. Workers catch CancelledError internally
    # and return normally, so gather yields their results rather than raising.
    results: list[BatchAddRequestsResult] = await asyncio.gather(*workers)

    # Combine the results from all worker tasks.
    return {
        'processed_requests': [req for result in results for req in result['processed_requests']],
        'unprocessed_requests': [req for result in results for req in result['unprocessed_requests']],
    }
622671

623672
async def batch_delete_requests(self: RequestQueueClientAsync, requests: list[dict]) -> dict:
624673
"""Delete given requests from the queue.

0 commit comments

Comments
 (0)