
Commit 8909dfd

Remove redundant retries from batch_add_requests

Align the sync and async versions so that both swallow exceptions but return the unprocessed requests. Add tests.
1 parent: e7bcf5c

2 files changed: +135 −73 lines changed


src/apify_client/clients/resource_clients/request_queue.py

Lines changed: 44 additions & 73 deletions
@@ -3,11 +3,9 @@
 import asyncio
 import logging
 import math
-from dataclasses import dataclass
-from datetime import timedelta
+from collections.abc import Iterable
 from queue import Queue
-from time import sleep
-from typing import TYPE_CHECKING, Any, TypedDict
+from typing import Any, TypedDict

 from apify_shared.utils import filter_out_none_values_recursively, ignore_docs, parse_date_fields
 from more_itertools import constrained_batches
@@ -16,9 +14,6 @@
 from apify_client._utils import catch_not_found_or_throw, pluck_data
 from apify_client.clients.base import ResourceClient, ResourceClientAsync

-if TYPE_CHECKING:
-    from collections.abc import Iterable
-
 logger = logging.getLogger(__name__)

 _RQ_MAX_REQUESTS_PER_BATCH = 25
@@ -41,17 +36,9 @@ class BatchAddRequestsResult(TypedDict):
     unprocessedRequests: list[dict]


-@dataclass
-class AddRequestsBatch:
-    """Batch of requests to add to the request queue.
-
-    Args:
-        requests: List of requests to be added to the request queue.
-        num_of_retries: Number of times this batch has been retried.
-    """
-
-    requests: Iterable[dict]
-    num_of_retries: int = 0
+def _get_unprocessed_request_from_request(request: dict[str, str]) -> dict[str, str]:
+    relevant_keys = {'url', 'uniqueKey', 'method'}
+    return {key: value for key, value in request.items() if key in relevant_keys}


 class RequestQueueClient(ResourceClient):
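
The `AddRequestsBatch` dataclass is gone; in its place, a small module-level helper projects a request dict down to the three fields the API reports back for unprocessed requests. A quick illustration of what it keeps (the helper is reproduced inline so the snippet runs standalone; the `userData` field is a hypothetical example of a dropped key):

def _get_unprocessed_request_from_request(request: dict[str, str]) -> dict[str, str]:
    relevant_keys = {'url', 'uniqueKey', 'method'}
    return {key: value for key, value in request.items() if key in relevant_keys}

request = {
    'uniqueKey': 'http://example.com/1',
    'url': 'http://example.com/1',
    'method': 'GET',
    'userData': {'label': 'DETAIL'},  # hypothetical extra field, dropped by the projection
}
assert _get_unprocessed_request_from_request(request) == {
    'uniqueKey': 'http://example.com/1',
    'url': 'http://example.com/1',
    'method': 'GET',
}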
@@ -297,8 +284,6 @@ def batch_add_requests(
         *,
         forefront: bool = False,
         max_parallel: int = 1,
-        max_unprocessed_requests_retries: int = 3,
-        min_delay_between_unprocessed_requests_retries: timedelta = timedelta(milliseconds=500),
     ) -> BatchAddRequestsResult:
         """Add requests to the request queue in batches.

@@ -312,9 +297,6 @@ def batch_add_requests(
             max_parallel: Specifies the maximum number of parallel tasks for API calls. This is only applicable
                 to the async client. For the sync client, this value must be set to 1, as parallel execution
                 is not supported.
-            max_unprocessed_requests_retries: Number of retry attempts for unprocessed requests.
-            min_delay_between_unprocessed_requests_retries: Minimum delay between retry attempts for unprocessed
-                requests.

         Returns:
             Result containing lists of processed and unprocessed requests.
@@ -335,42 +317,43 @@
         )

         # Put the batches into the queue for processing.
-        queue = Queue[AddRequestsBatch]()
+        queue = Queue[Iterable[dict]]()

-        for b in batches:
-            queue.put(AddRequestsBatch(b))
+        for batch in batches:
+            queue.put(batch)

         processed_requests = list[dict]()
-        unprocessed_requests = list[dict]()
+        unprocessed_requests = dict[str, dict]()

         # Process all batches in the queue sequentially.
         while not queue.empty():
-            batch = queue.get()
+            request_batch = queue.get()
+            # All requests are considered unprocessed unless explicitly mentioned in `processedRequests` response.
+            for request in request_batch:
+                unprocessed_requests[request['uniqueKey']] = _get_unprocessed_request_from_request(request)

-            # Send the batch to the API.
-            response = self.http_client.call(
-                url=self._url('requests/batch'),
-                method='POST',
-                params=request_params,
-                json=list(batch.requests),
-                timeout_secs=_MEDIUM_TIMEOUT,
-            )
-
-            # Retry if the request failed and the retry limit has not been reached.
-            if not response.is_success and batch.num_of_retries < max_unprocessed_requests_retries:
-                batch.num_of_retries += 1
-                sleep(min_delay_between_unprocessed_requests_retries.total_seconds())
-                queue.put(batch)
+            try:
+                # Send the batch to the API.
+                response = self.http_client.call(
+                    url=self._url('requests/batch'),
+                    method='POST',
+                    params=request_params,
+                    json=list(request_batch),
+                    timeout_secs=_MEDIUM_TIMEOUT,
+                )

-            # Otherwise, add the processed/unprocessed requests to their respective lists.
-            else:
                 response_parsed = parse_date_fields(pluck_data(response.json()))
                 processed_requests.extend(response_parsed.get('processedRequests', []))
-                unprocessed_requests.extend(response_parsed.get('unprocessedRequests', []))
+
+                for processed_request in response_parsed.get('processedRequests', []):
+                    unprocessed_requests.pop(processed_request['uniqueKey'], None)
+
+            except Exception as exc:
+                logger.warning(f'Error occurred while processing a batch of requests: {exc}')

         return {
             'processedRequests': processed_requests,
-            'unprocessedRequests': unprocessed_requests,
+            'unprocessedRequests': list(unprocessed_requests.values()),
         }

     def batch_delete_requests(self, requests: list[dict]) -> dict:
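
The sync path now uses pessimistic bookkeeping: every request in a batch is recorded as unprocessed up front, keyed by `uniqueKey`, and removed only once the server explicitly confirms it in `processedRequests`. Keying by `uniqueKey` makes the removal an O(1) pop and naturally deduplicates requests. A runnable distillation of the pattern (the sample parsed response is hypothetical):

requests = [
    {'uniqueKey': 'a', 'url': 'http://example.com/a', 'method': 'GET'},
    {'uniqueKey': 'b', 'url': 'http://example.com/b', 'method': 'GET'},
]

# Assume everything failed until the server says otherwise.
unprocessed = {r['uniqueKey']: r for r in requests}

# Hypothetical parsed server response confirming only the first request.
response_parsed = {'processedRequests': [{'uniqueKey': 'a'}]}
for confirmed in response_parsed.get('processedRequests', []):
    unprocessed.pop(confirmed['uniqueKey'], None)

# Whatever is left is reported back to the caller.
assert list(unprocessed.values()) == [requests[1]]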
@@ -661,24 +644,26 @@ async def delete_request_lock(

     async def _batch_add_requests_worker(
         self,
-        queue: asyncio.Queue[AddRequestsBatch],
+        queue: asyncio.Queue[Iterable[dict]],
         request_params: dict,
-        max_unprocessed_requests_retries: int,
-        min_delay_between_unprocessed_requests_retries: timedelta,
     ) -> BatchAddRequestsResult:
         """Worker function to process a batch of requests.

-        This worker will process batches from the queue, retrying requests that fail until the retry limit is reached.
+        This worker will process batches from the queue.

         Return result containing lists of processed and unprocessed requests by the worker.
         """
         processed_requests = list[dict]()
-        unprocessed_requests = list[dict]()
+        unprocessed_requests = dict[str, dict]()

         while True:
             # Get the next batch from the queue.
             try:
-                batch = await queue.get()
+                request_batch = await queue.get()
+                # All requests are considered unprocessed unless explicitly mentioned in `processedRequests` response.
+                for request in request_batch:
+                    unprocessed_requests[request['uniqueKey']] = _get_unprocessed_request_from_request(request)
+
             except asyncio.CancelledError:
                 break

@@ -688,22 +673,15 @@ async def _batch_add_requests_worker(
                     url=self._url('requests/batch'),
                     method='POST',
                     params=request_params,
-                    json=list(batch.requests),
+                    json=list(request_batch),
                     timeout_secs=_MEDIUM_TIMEOUT,
                 )

                 response_parsed = parse_date_fields(pluck_data(response.json()))
+                processed_requests.extend(response_parsed.get('processedRequests', []))

-                # Retry if the request failed and the retry limit has not been reached.
-                if not response.is_success and batch.num_of_retries < max_unprocessed_requests_retries:
-                    batch.num_of_retries += 1
-                    await asyncio.sleep(min_delay_between_unprocessed_requests_retries.total_seconds())
-                    await queue.put(batch)
-
-                # Otherwise, add the processed/unprocessed requests to their respective lists.
-                else:
-                    processed_requests.extend(response_parsed.get('processedRequests', []))
-                    unprocessed_requests.extend(response_parsed.get('unprocessedRequests', []))
+                for processed_request in response_parsed.get('processedRequests', []):
+                    unprocessed_requests.pop(processed_request['uniqueKey'], None)

             except Exception as exc:
                 logger.warning(f'Error occurred while processing a batch of requests: {exc}')
@@ -714,7 +692,7 @@ async def _batch_add_requests_worker(

         return {
             'processedRequests': processed_requests,
-            'unprocessedRequests': unprocessed_requests,
+            'unprocessedRequests': list(unprocessed_requests.values()),
         }

     async def batch_add_requests(
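
The async variant keeps the same bookkeeping but fans batches out to worker tasks that drain a shared `asyncio.Queue` until cancelled. A self-contained sketch of that worker-pool shape (the names `worker` and `main` and the two-item payload are illustrative, not from the client):

import asyncio


async def worker(queue: asyncio.Queue, out: list[dict]) -> None:
    # Drain batches until the task is cancelled, as the client's worker does.
    while True:
        try:
            batch = await queue.get()
        except asyncio.CancelledError:
            break
        out.extend(batch)


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    for batch in ([{'uniqueKey': 'a'}], [{'uniqueKey': 'b'}]):
        await queue.put(batch)

    out: list[dict] = []
    tasks = [asyncio.create_task(worker(queue, out)) for _ in range(2)]
    while not queue.empty():
        await asyncio.sleep(0)  # yield so the workers can drain the queue

    for task in tasks:
        task.cancel()
    await asyncio.gather(*tasks, return_exceptions=True)
    assert {r['uniqueKey'] for r in out} == {'a', 'b'}


asyncio.run(main())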
@@ -723,8 +701,6 @@ async def batch_add_requests(
         *,
         forefront: bool = False,
         max_parallel: int = 5,
-        max_unprocessed_requests_retries: int = 3,
-        min_delay_between_unprocessed_requests_retries: timedelta = timedelta(milliseconds=500),
     ) -> BatchAddRequestsResult:
         """Add requests to the request queue in batches.

@@ -738,15 +714,12 @@
             max_parallel: Specifies the maximum number of parallel tasks for API calls. This is only applicable
                 to the async client. For the sync client, this value must be set to 1, as parallel execution
                 is not supported.
-            max_unprocessed_requests_retries: Number of retry attempts for unprocessed requests.
-            min_delay_between_unprocessed_requests_retries: Minimum delay between retry attempts for unprocessed
-                requests.

         Returns:
             Result containing lists of processed and unprocessed requests.
         """
         tasks = set[asyncio.Task]()
-        queue: asyncio.Queue[AddRequestsBatch] = asyncio.Queue()
+        queue: asyncio.Queue[Iterable[dict]] = asyncio.Queue()
         request_params = self._params(clientKey=self.client_key, forefront=forefront)

         # Compute the payload size limit to ensure it doesn't exceed the maximum allowed size.
@@ -760,15 +733,13 @@
         )

         for batch in batches:
-            await queue.put(AddRequestsBatch(batch))
+            await queue.put(batch)

         # Start a required number of worker tasks to process the batches.
         for i in range(max_parallel):
             coro = self._batch_add_requests_worker(
                 queue,
                 request_params,
-                max_unprocessed_requests_retries,
-                min_delay_between_unprocessed_requests_retries,
             )
             task = asyncio.create_task(coro, name=f'batch_add_requests_worker_{i}')
             tasks.add(task)
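
With `max_unprocessed_requests_retries` and `min_delay_between_unprocessed_requests_retries` removed, retrying becomes the caller's decision. A minimal sketch of a caller-side retry loop against the async client (the function name, token placeholder, and back-off value are hypothetical; the `batch_add_requests` call matches the usage in the tests below):

import asyncio

from apify_client import ApifyClientAsync


async def add_with_retries(queue_id: str, requests: list[dict], attempts: int = 3) -> list[dict]:
    """Re-submit whatever comes back as unprocessed; return what still failed."""
    rq_client = ApifyClientAsync(token='<APIFY_TOKEN>').request_queue(request_queue_id=queue_id)
    remaining = requests
    for _ in range(attempts):
        result = await rq_client.batch_add_requests(requests=remaining)
        remaining = result['unprocessedRequests']
        if not remaining:
            break
        await asyncio.sleep(0.5)  # hypothetical back-off between attempts
    return remaining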
Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
+import respx
+
+from apify_client import ApifyClient, ApifyClientAsync
+
+_PARTIALLY_ADDED_BATCH_RESPONSE_CONTENT = """{
+    "data": {
+        "processedRequests": [
+            {
+                "requestId": "YiKoxjkaS9gjGTqhF",
+                "uniqueKey": "http://example.com/1",
+                "wasAlreadyPresent": true,
+                "wasAlreadyHandled": false
+            }
+        ],
+        "unprocessedRequests": [
+            {
+                "uniqueKey": "http://example.com/2",
+                "url": "http://example.com/2",
+                "method": "GET"
+            }
+        ]
+    }
+}"""
+
+
+@respx.mock
+async def test_batch_not_processed_due_to_exception_async() -> None:
+    """Test that all requests are unprocessed unless explicitly stated by the server that they have been processed."""
+    client = ApifyClientAsync(token='')
+
+    respx.route(method='POST', host='api.apify.com').mock(return_value=respx.MockResponse(401))
+    requests = [
+        {'uniqueKey': 'http://example.com/1', 'url': 'http://example.com/1', 'method': 'GET'},
+        {'uniqueKey': 'http://example.com/2', 'url': 'http://example.com/2', 'method': 'GET'},
+    ]
+    rq_client = client.request_queue(request_queue_id='whatever')
+
+    response = await rq_client.batch_add_requests(requests=requests)
+    assert response['unprocessedRequests'] == requests
+
+
+@respx.mock
+async def test_batch_processed_partially_async() -> None:
+    client = ApifyClientAsync(token='')
+
+    respx.route(method='POST', host='api.apify.com').mock(
+        return_value=respx.MockResponse(200, content=_PARTIALLY_ADDED_BATCH_RESPONSE_CONTENT)
+    )
+    requests = [
+        {'uniqueKey': 'http://example.com/1', 'url': 'http://example.com/1', 'method': 'GET'},
+        {'uniqueKey': 'http://example.com/2', 'url': 'http://example.com/2', 'method': 'GET'},
+    ]
+    rq_client = client.request_queue(request_queue_id='whatever')
+
+    response = await rq_client.batch_add_requests(requests=requests)
+    assert requests[0]['uniqueKey'] in {request['uniqueKey'] for request in response['processedRequests']}
+    assert response['unprocessedRequests'] == [requests[1]]
+
+
+@respx.mock
+def test_batch_not_processed_due_to_exception_sync() -> None:
+    """Test that all requests are unprocessed unless explicitly stated by the server that they have been processed."""
+    client = ApifyClient(token='')
+
+    respx.route(method='POST', host='api.apify.com').mock(return_value=respx.MockResponse(401))
+    requests = [
+        {'uniqueKey': 'http://example.com/1', 'url': 'http://example.com/1', 'method': 'GET'},
+        {'uniqueKey': 'http://example.com/2', 'url': 'http://example.com/2', 'method': 'GET'},
+    ]
+    rq_client = client.request_queue(request_queue_id='whatever')
+
+    response = rq_client.batch_add_requests(requests=requests)
+    assert response['unprocessedRequests'] == requests
+
+
+@respx.mock
+def test_batch_processed_partially_sync() -> None:
+    client = ApifyClient(token='')
+
+    respx.route(method='POST', host='api.apify.com').mock(
+        return_value=respx.MockResponse(200, content=_PARTIALLY_ADDED_BATCH_RESPONSE_CONTENT)
+    )
+    requests = [
+        {'uniqueKey': 'http://example.com/1', 'url': 'http://example.com/1', 'method': 'GET'},
+        {'uniqueKey': 'http://example.com/2', 'url': 'http://example.com/2', 'method': 'GET'},
+    ]
+    rq_client = client.request_queue(request_queue_id='whatever')
+
+    response = rq_client.batch_add_requests(requests=requests)
+    assert requests[0]['uniqueKey'] in {request['uniqueKey'] for request in response['processedRequests']}
+    assert response['unprocessedRequests'] == [requests[1]]
