88 changes: 18 additions & 70 deletions src/apify_client/clients/resource_clients/request_queue.py
@@ -3,11 +3,9 @@
 import asyncio
 import logging
 import math
-from dataclasses import dataclass
-from datetime import timedelta
+from collections.abc import Iterable
 from queue import Queue
-from time import sleep
-from typing import TYPE_CHECKING, Any, TypedDict
+from typing import Any, TypedDict
 
 from apify_shared.utils import filter_out_none_values_recursively, ignore_docs, parse_date_fields
 from more_itertools import constrained_batches
@@ -16,9 +14,6 @@
 from apify_client._utils import catch_not_found_or_throw, pluck_data
 from apify_client.clients.base import ResourceClient, ResourceClientAsync
 
-if TYPE_CHECKING:
-    from collections.abc import Iterable
-
 logger = logging.getLogger(__name__)
 
 _RQ_MAX_REQUESTS_PER_BATCH = 25
@@ -41,19 +36,6 @@ class BatchAddRequestsResult(TypedDict):
     unprocessedRequests: list[dict]
 
 
-@dataclass
-class AddRequestsBatch:
-    """Batch of requests to add to the request queue.
-
-    Args:
-        requests: List of requests to be added to the request queue.
-        num_of_retries: Number of times this batch has been retried.
-    """
-
-    requests: Iterable[dict]
-    num_of_retries: int = 0
-
-
 class RequestQueueClient(ResourceClient):
     """Sub-client for manipulating a single request queue."""
 
@@ -297,8 +279,6 @@ def batch_add_requests(
         *,
         forefront: bool = False,
         max_parallel: int = 1,
-        max_unprocessed_requests_retries: int = 3,
-        min_delay_between_unprocessed_requests_retries: timedelta = timedelta(milliseconds=500),
     ) -> BatchAddRequestsResult:
         """Add requests to the request queue in batches.
 
@@ -312,9 +292,6 @@ def batch_add_requests(
             max_parallel: Specifies the maximum number of parallel tasks for API calls. This is only applicable
                 to the async client. For the sync client, this value must be set to 1, as parallel execution
                 is not supported.
-            max_unprocessed_requests_retries: Number of retry attempts for unprocessed requests.
-            min_delay_between_unprocessed_requests_retries: Minimum delay between retry attempts for unprocessed
-                requests.
 
         Returns:
             Result containing lists of processed and unprocessed requests.
@@ -335,38 +312,30 @@ def batch_add_requests(
         )
 
         # Put the batches into the queue for processing.
-        queue = Queue[AddRequestsBatch]()
+        queue = Queue[Iterable[dict]]()
 
-        for b in batches:
-            queue.put(AddRequestsBatch(b))
+        for batch in batches:
+            queue.put(batch)
 
         processed_requests = list[dict]()
         unprocessed_requests = list[dict]()
 
         # Process all batches in the queue sequentially.
         while not queue.empty():
-            batch = queue.get()
+            request_batch = queue.get()
 
             # Send the batch to the API.
             response = self.http_client.call(
                 url=self._url('requests/batch'),
                 method='POST',
                 params=request_params,
-                json=list(batch.requests),
+                json=list(request_batch),
                 timeout_secs=_MEDIUM_TIMEOUT,
             )
 
-            # Retry if the request failed and the retry limit has not been reached.
-            if not response.is_success and batch.num_of_retries < max_unprocessed_requests_retries:
-                batch.num_of_retries += 1
-                sleep(min_delay_between_unprocessed_requests_retries.total_seconds())
-                queue.put(batch)
-
-            # Otherwise, add the processed/unprocessed requests to their respective lists.
-            else:
-                response_parsed = parse_date_fields(pluck_data(response.json()))
-                processed_requests.extend(response_parsed.get('processedRequests', []))
-                unprocessed_requests.extend(response_parsed.get('unprocessedRequests', []))
+            response_parsed = parse_date_fields(pluck_data(response.json()))
+            processed_requests.extend(response_parsed.get('processedRequests', []))
+            unprocessed_requests.extend(response_parsed.get('unprocessedRequests', []))
 
         return {
             'processedRequests': processed_requests,
@@ -661,14 +630,12 @@ async def delete_request_lock(
 
     async def _batch_add_requests_worker(
         self,
-        queue: asyncio.Queue[AddRequestsBatch],
+        queue: asyncio.Queue[Iterable[dict]],
         request_params: dict,
-        max_unprocessed_requests_retries: int,
-        min_delay_between_unprocessed_requests_retries: timedelta,
     ) -> BatchAddRequestsResult:
         """Worker function to process a batch of requests.
 
-        This worker will process batches from the queue, retrying requests that fail until the retry limit is reached.
+        This worker will process batches from the queue.
 
        Return result containing lists of processed and unprocessed requests by the worker.
        """
@@ -678,7 +645,7 @@ async def _batch_add_requests_worker(
         while True:
             # Get the next batch from the queue.
             try:
-                batch = await queue.get()
+                request_batch = await queue.get()
             except asyncio.CancelledError:
                 break
 
@@ -688,25 +655,13 @@ async def _batch_add_requests_worker(
                     url=self._url('requests/batch'),
                     method='POST',
                     params=request_params,
-                    json=list(batch.requests),
+                    json=list(request_batch),
                     timeout_secs=_MEDIUM_TIMEOUT,
                 )
 
                 response_parsed = parse_date_fields(pluck_data(response.json()))
-
-                # Retry if the request failed and the retry limit has not been reached.
-                if not response.is_success and batch.num_of_retries < max_unprocessed_requests_retries:
-                    batch.num_of_retries += 1
-                    await asyncio.sleep(min_delay_between_unprocessed_requests_retries.total_seconds())
-                    await queue.put(batch)
-
-                # Otherwise, add the processed/unprocessed requests to their respective lists.
-                else:
-                    processed_requests.extend(response_parsed.get('processedRequests', []))
-                    unprocessed_requests.extend(response_parsed.get('unprocessedRequests', []))
-
-            except Exception as exc:
-                logger.warning(f'Error occurred while processing a batch of requests: {exc}')
+                processed_requests.extend(response_parsed.get('processedRequests', []))
+                unprocessed_requests.extend(response_parsed.get('unprocessedRequests', []))
 
             finally:
                 # Mark the batch as done whether it succeeded or failed.
@@ -723,8 +678,6 @@ async def batch_add_requests(
         *,
         forefront: bool = False,
         max_parallel: int = 5,
-        max_unprocessed_requests_retries: int = 3,
-        min_delay_between_unprocessed_requests_retries: timedelta = timedelta(milliseconds=500),
     ) -> BatchAddRequestsResult:
         """Add requests to the request queue in batches.
 
@@ -738,15 +691,12 @@ async def batch_add_requests(
             max_parallel: Specifies the maximum number of parallel tasks for API calls. This is only applicable
                 to the async client. For the sync client, this value must be set to 1, as parallel execution
                 is not supported.
-            max_unprocessed_requests_retries: Number of retry attempts for unprocessed requests.
-            min_delay_between_unprocessed_requests_retries: Minimum delay between retry attempts for unprocessed
-                requests.
 
         Returns:
             Result containing lists of processed and unprocessed requests.
         """
         tasks = set[asyncio.Task]()
-        queue: asyncio.Queue[AddRequestsBatch] = asyncio.Queue()
+        queue: asyncio.Queue[Iterable[dict]] = asyncio.Queue()
         request_params = self._params(clientKey=self.client_key, forefront=forefront)
 
         # Compute the payload size limit to ensure it doesn't exceed the maximum allowed size.
@@ -760,15 +710,13 @@ async def batch_add_requests(
         )
 
         for batch in batches:
-            await queue.put(AddRequestsBatch(batch))
+            await queue.put(batch)
 
         # Start a required number of worker tasks to process the batches.
         for i in range(max_parallel):
             coro = self._batch_add_requests_worker(
                 queue,
                 request_params,
-                max_unprocessed_requests_retries,
-                min_delay_between_unprocessed_requests_retries,
             )
             task = asyncio.create_task(coro, name=f'batch_add_requests_worker_{i}')
             tasks.add(task)
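Since the built-in retrying of unprocessed requests is removed, a caller that relied on max_unprocessed_requests_retries can reimplement it on its side. A minimal sketch, assuming hypothetical names add_requests_with_retry, max_attempts, and delay_secs (none of these are part of the client API):

import asyncio

from apify_client import ApifyClientAsync


async def add_requests_with_retry(
    client: ApifyClientAsync,
    queue_id: str,
    requests: list[dict],
    max_attempts: int = 3,  # hypothetical caller-side retry limit
    delay_secs: float = 0.5,  # hypothetical delay between attempts
) -> dict:
    processed: list[dict] = []
    remaining = requests

    for _ in range(max_attempts):
        # HTTP-level failures (e.g. a 401 response) now raise ApifyApiError and
        # propagate out of this call instead of being logged and swallowed.
        result = await client.request_queue(queue_id).batch_add_requests(requests=remaining)
        processed.extend(result['processedRequests'])

        # Resubmit only the requests the API reported back as unprocessed.
        remaining = result['unprocessedRequests']
        if not remaining:
            break
        await asyncio.sleep(delay_secs)

    return {'processedRequests': processed, 'unprocessedRequests': remaining}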
93 changes: 93 additions & 0 deletions tests/unit/test_client_request_queue.py
@@ -0,0 +1,93 @@
import pytest
import respx

import apify_client
from apify_client import ApifyClient, ApifyClientAsync

_PARTIALLY_ADDED_BATCH_RESPONSE_CONTENT = """{
"data": {
"processedRequests": [
{
"requestId": "YiKoxjkaS9gjGTqhF",
"uniqueKey": "http://example.com/1",
"wasAlreadyPresent": true,
"wasAlreadyHandled": false
}
],
"unprocessedRequests": [
{
"uniqueKey": "http://example.com/2",
"url": "http://example.com/2",
"method": "GET"
}
]
}
}"""


@respx.mock
async def test_batch_not_processed_raises_exception_async() -> None:
"""Test that client exceptions are not silently ignored"""
client = ApifyClientAsync(token='')

respx.route(method='POST', host='api.apify.com').mock(return_value=respx.MockResponse(401))
requests = [
{'uniqueKey': 'http://example.com/1', 'url': 'http://example.com/1', 'method': 'GET'},
{'uniqueKey': 'http://example.com/2', 'url': 'http://example.com/2', 'method': 'GET'},
]
rq_client = client.request_queue(request_queue_id='whatever')

with pytest.raises(apify_client._errors.ApifyApiError):
await rq_client.batch_add_requests(requests=requests)


@respx.mock
async def test_batch_processed_partially_async() -> None:
client = ApifyClientAsync(token='')

respx.route(method='POST', host='api.apify.com').mock(
return_value=respx.MockResponse(200, content=_PARTIALLY_ADDED_BATCH_RESPONSE_CONTENT)
)
requests = [
{'uniqueKey': 'http://example.com/1', 'url': 'http://example.com/1', 'method': 'GET'},
{'uniqueKey': 'http://example.com/2', 'url': 'http://example.com/2', 'method': 'GET'},
]
rq_client = client.request_queue(request_queue_id='whatever')

response = await rq_client.batch_add_requests(requests=requests)
assert requests[0]['uniqueKey'] in {request['uniqueKey'] for request in response['processedRequests']}
assert response['unprocessedRequests'] == [requests[1]]


@respx.mock
def test_batch_not_processed_raises_exception_sync() -> None:
"""Test that client exceptions are not silently ignored"""
client = ApifyClient(token='')

respx.route(method='POST', host='api.apify.com').mock(return_value=respx.MockResponse(401))
requests = [
{'uniqueKey': 'http://example.com/1', 'url': 'http://example.com/1', 'method': 'GET'},
{'uniqueKey': 'http://example.com/2', 'url': 'http://example.com/2', 'method': 'GET'},
]
rq_client = client.request_queue(request_queue_id='whatever')

with pytest.raises(apify_client._errors.ApifyApiError):
rq_client.batch_add_requests(requests=requests)


@respx.mock
def test_batch_processed_partially_sync() -> None:
client = ApifyClient(token='')

respx.route(method='POST', host='api.apify.com').mock(
return_value=respx.MockResponse(200, content=_PARTIALLY_ADDED_BATCH_RESPONSE_CONTENT)
)
requests = [
{'uniqueKey': 'http://example.com/1', 'url': 'http://example.com/1', 'method': 'GET'},
{'uniqueKey': 'http://example.com/2', 'url': 'http://example.com/2', 'method': 'GET'},
]
rq_client = client.request_queue(request_queue_id='whatever')

response = rq_client.batch_add_requests(requests=requests)
assert requests[0]['uniqueKey'] in {request['uniqueKey'] for request in response['processedRequests']}
assert response['unprocessedRequests'] == [requests[1]]
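
From the caller's perspective, the contract these tests exercise is: an API-level error raises ApifyApiError, while a 200 response may still report some requests as unprocessed. A short usage sketch (the token and queue ID are placeholders):

import apify_client
from apify_client import ApifyClient

client = ApifyClient(token='my-token')  # placeholder token
rq_client = client.request_queue(request_queue_id='my-queue-id')  # placeholder queue ID

requests = [
    {'uniqueKey': 'http://example.com/1', 'url': 'http://example.com/1', 'method': 'GET'},
]

try:
    result = rq_client.batch_add_requests(requests=requests)
except apify_client._errors.ApifyApiError as exc:
    # Previously the async worker logged this and moved on; now it reaches the caller.
    print(f'Batch insert failed: {exc}')
else:
    if result['unprocessedRequests']:
        print(f"Some requests were not added: {result['unprocessedRequests']}")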