10
10
import re
11
11
import string
12
12
import unicodedata
13
- from collections .abc import Awaitable , Collection , Iterable , Iterator , Mapping
13
+ from collections .abc import Awaitable , Callable , Collection , Iterable , Iterator , Mapping
14
14
from datetime import datetime
15
15
from functools import reduce
16
16
from http import HTTPStatus
26
26
from pybtex .style .formatting import unsrtalpha
27
27
from pybtex .style .template import FieldIsMissing
28
28
from tenacity import (
29
+ AsyncRetrying ,
29
30
before_sleep_log ,
30
- retry ,
31
31
retry_if_exception ,
32
32
stop_after_attempt ,
33
33
wait_incrementing ,
@@ -402,7 +402,10 @@ def create_bibtex_key(author: list[str], year: str | int, title: str) -> str:
402
402
return remove_substrings (key , FORBIDDEN_KEY_CHARACTERS )
403
403
404
404
405
- def is_retryable (exc : BaseException ) -> bool :
405
+ def is_retryable (
406
+ exc : BaseException ,
407
+ additional_status_codes : Collection [HTTPStatus | int ] | None = None ,
408
+ ) -> bool :
406
409
"""Check if an exception is known to be a retryable HTTP issue."""
407
410
if isinstance (
408
411
exc , aiohttp .ServerDisconnectedError | aiohttp .ClientConnectionResetError
@@ -411,33 +414,56 @@ def is_retryable(exc: BaseException) -> bool:
411
414
# > aiohttp.client_exceptions.ClientConnectionResetError:
412
415
# > Cannot write to closing transport
413
416
return True
414
- return isinstance ( exc , aiohttp . ClientResponseError ) and exc . status in {
417
+ retry_status_codes : set [ int ] = {
415
418
httpx .codes .INTERNAL_SERVER_ERROR .value ,
416
419
httpx .codes .GATEWAY_TIMEOUT .value ,
417
420
}
421
+ if additional_status_codes :
422
+ retry_status_codes .update (
423
+ status_code .value if isinstance (status_code , HTTPStatus ) else status_code
424
+ for status_code in additional_status_codes
425
+ )
426
+ return (
427
+ isinstance (exc , aiohttp .ClientResponseError )
428
+ and exc .status in retry_status_codes
429
+ )
418
430
419
431
420
- @retry (
421
- retry = retry_if_exception (is_retryable ),
422
- before_sleep = before_sleep_log (logger , logging .WARNING ),
423
- stop = stop_after_attempt (5 ),
424
- wait = wait_incrementing (0.1 , 0.1 ),
425
- )
426
- async def _get_with_retrying (
432
+ async def _get_with_retrying ( # type: ignore[return] # noqa: RET503
427
433
url : str ,
428
434
session : aiohttp .ClientSession ,
429
435
http_exception_mappings : dict [HTTPStatus | int , Exception ] | None = None ,
436
+ retry_predicate : Callable [[BaseException ], bool ] = is_retryable ,
430
437
** get_kwargs ,
431
438
) -> dict [str , Any ]:
432
- """Get from a URL with retrying protection."""
433
- try :
434
- async with session .get (url , ** get_kwargs ) as response :
435
- response .raise_for_status ()
436
- return await response .json ()
437
- except aiohttp .ClientResponseError as e :
438
- if http_exception_mappings and e .status in http_exception_mappings :
439
- raise http_exception_mappings [e .status ] from e
440
- raise
439
+ """Get from a URL with retrying protection.
440
+
441
+ Args:
442
+ url: Target URL for the GET request.
443
+ session: Session for the GET request.
444
+ http_exception_mappings: Optional mapping of HTTP status codes to
445
+ custom Exceptions to be thrown.
446
+ retry_predicate: Optional predicate to dictate when to retry.
447
+ **get_kwargs: Optional additional keyword arguments for the GET request.
448
+
449
+ Returns:
450
+ JSON result from the GET request.
451
+ """
452
+ async for attempt in AsyncRetrying (
453
+ retry = retry_if_exception (retry_predicate ),
454
+ before_sleep = before_sleep_log (logger , logging .WARNING ),
455
+ stop = stop_after_attempt (5 ),
456
+ wait = wait_incrementing (0.1 , 0.1 ),
457
+ ):
458
+ with attempt :
459
+ try :
460
+ async with session .get (url , ** get_kwargs ) as response :
461
+ response .raise_for_status ()
462
+ return await response .json ()
463
+ except aiohttp .ClientResponseError as e :
464
+ if http_exception_mappings and e .status in http_exception_mappings :
465
+ raise http_exception_mappings [e .status ] from e
466
+ raise
441
467
442
468
443
469
def union_collections_to_ordered_list (collections : Iterable ) -> list :
0 commit comments