Skip to content

Commit 440dd91

Browse files
authored
Merge pull request #83 from GitGuardian/agateau/report-rate-limit
Handle API rate-limits
2 parents 6e0d454 + b52f1fd commit 440dd91

File tree

5 files changed

+130
-19
lines changed

5 files changed

+130
-19
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
### Added
2+
3+
- GGClient now obeys rate-limits and can notify callers when hitting one.

pygitguardian/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
"""PyGitGuardian API Client"""
2-
from .client import ContentTooLarge, GGClient
2+
from .client import ContentTooLarge, GGClient, GGClientCallbacks
33

44

55
__version__ = "1.11.0"
66
GGClient._version = __version__
77

8-
__all__ = ["GGClient", "ContentTooLarge"]
8+
__all__ = ["GGClient", "GGClientCallbacks", "ContentTooLarge"]

pygitguardian/client.py

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import tarfile
55
import time
66
import urllib.parse
7+
from abc import ABC, abstractmethod
78
from io import BytesIO
89
from pathlib import Path
910
from typing import Any, Dict, List, Optional, Union, cast
@@ -46,6 +47,8 @@
4647
# max files size to create a tar from
4748
MAX_TAR_CONTENT_SIZE = 30 * 1024 * 1024
4849

50+
HTTP_TOO_MANY_REQUESTS = 429
51+
4952

5053
class ContentTooLarge(Exception):
5154
"""
@@ -124,6 +127,15 @@ def _create_tar(root_path: Path, filenames: List[str]) -> bytes:
124127
return tar_stream.getvalue()
125128

126129

130+
class GGClientCallbacks(ABC):
131+
"""Abstract class used to notify GGClient users of events"""
132+
133+
@abstractmethod
134+
def on_rate_limited(self, delay: int) -> None:
135+
"""Called when GGClient hits a rate-limit."""
136+
... # pragma: no cover
137+
138+
127139
class GGClient:
128140
_version = "undefined"
129141
session: Session
@@ -133,6 +145,7 @@ class GGClient:
133145
user_agent: str
134146
extra_headers: Dict
135147
secret_scan_preferences: SecretScanPreferences
148+
callbacks: Optional[GGClientCallbacks]
136149

137150
def __init__(
138151
self,
@@ -141,13 +154,15 @@ def __init__(
141154
session: Optional[Session] = None,
142155
user_agent: Optional[str] = None,
143156
timeout: Optional[float] = DEFAULT_TIMEOUT,
157+
callbacks: Optional[GGClientCallbacks] = None,
144158
):
145159
"""
146160
:param api_key: API Key to be added to requests
147161
:param base_uri: Base URI for the API, defaults to "https://api.gitguardian.com"
148162
:param session: custom requests session, defaults to requests.Session()
149163
:param user_agent: user agent to identify requests, defaults to ""
150164
:param timeout: request timeout, defaults to 20s
165+
:param callbacks: object used to receive callbacks from the client, defaults to None
151166
152167
:raises ValueError: if the protocol or the api_key is invalid
153168
"""
@@ -177,6 +192,7 @@ def __init__(
177192
self.api_key = api_key
178193
self.session = session if isinstance(session, Session) else Session()
179194
self.timeout = timeout
195+
self.callbacks = callbacks
180196
self.user_agent = "pygitguardian/{} ({};py{})".format(
181197
self._version, platform.system(), platform.python_version()
182198
)
@@ -207,18 +223,35 @@ def request(
207223
if extra_headers
208224
else self.session.headers
209225
)
210-
start = time.time()
211-
response: Response = self.session.request(
212-
method=method, url=url, timeout=self.timeout, headers=headers, **kwargs
213-
)
214-
duration = time.time() - start
215-
logger.debug(
216-
"method=%s endpoint=%s status_code=%s duration=%f",
217-
method,
218-
endpoint,
219-
response.status_code,
220-
duration,
221-
)
226+
while True:
227+
start = time.time()
228+
response: Response = self.session.request(
229+
method=method, url=url, timeout=self.timeout, headers=headers, **kwargs
230+
)
231+
duration = time.time() - start
232+
logger.debug(
233+
"method=%s endpoint=%s status_code=%s duration=%f",
234+
method,
235+
endpoint,
236+
response.status_code,
237+
duration,
238+
)
239+
if response.status_code == HTTP_TOO_MANY_REQUESTS:
240+
logger.warning("Rate-limit hit")
241+
try:
242+
delay = int(response.headers["Retry-After"])
243+
except (ValueError, KeyError):
244+
# We failed to parse the Retry-After header, return the response as
245+
# is so the caller handles it as an error
246+
logger.error("Could not get the retry-after value")
247+
return response
248+
249+
if self.callbacks:
250+
self.callbacks.on_rate_limited(delay)
251+
logger.warning("Waiting for %d seconds before retrying", delay)
252+
time.sleep(delay)
253+
else:
254+
break
222255

223256
self.app_version = response.headers.get("X-App-Version", self.app_version)
224257
self.secrets_engine_version = response.headers.get(

tests/conftest.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
from os.path import dirname, join, realpath
3+
from typing import Any
34

45
import pytest
56
import vcr
@@ -19,7 +20,12 @@
1920
)
2021

2122

23+
def create_client(**kwargs: Any) -> GGClient:
24+
"""Create a GGClient using $GITGUARDIAN_API_KEY"""
25+
api_key = os.environ["GITGUARDIAN_API_KEY"]
26+
return GGClient(api_key=api_key, **kwargs)
27+
28+
2229
@pytest.fixture
2330
def client():
24-
api_key = os.environ["GITGUARDIAN_API_KEY"]
25-
return GGClient(api_key=api_key)
31+
return create_client()

tests/test_client.py

Lines changed: 72 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@
55
from datetime import date
66
from io import BytesIO
77
from typing import Any, Dict, List, Optional, Tuple, Type
8-
from unittest.mock import patch
8+
from unittest.mock import Mock, patch
99

1010
import pytest
1111
import responses
1212
from marshmallow import ValidationError
1313
from responses import matchers
1414

1515
from pygitguardian import GGClient
16-
from pygitguardian.client import is_ok, load_detail
16+
from pygitguardian.client import GGClientCallbacks, is_ok, load_detail
1717
from pygitguardian.config import (
1818
DEFAULT_BASE_URI,
1919
DOCUMENT_SIZE_THRESHOLD_BYTES,
@@ -36,7 +36,7 @@
3636
SCAVulnerability,
3737
)
3838

39-
from .conftest import my_vcr
39+
from .conftest import create_client, my_vcr
4040

4141

4242
FILENAME = ".env"
@@ -612,6 +612,75 @@ def test_multiscan_parameters(
612612
assert mock_response.call_count == 1
613613

614614

615+
@responses.activate
616+
def test_rate_limit():
617+
"""
618+
GIVEN a GGClient instance with callbacks
619+
WHEN calling an API endpoint and we hit a rate-limit
620+
THEN the client retries after the delay
621+
AND the `on_rate_limited()` method of the callback is called
622+
"""
623+
callbacks = Mock(spec=GGClientCallbacks)
624+
625+
client = create_client(callbacks=callbacks)
626+
multiscan_url = client._url_from_endpoint("multiscan", "v1")
627+
628+
rate_limit_response = responses.post(
629+
url=multiscan_url,
630+
status=429,
631+
headers={"Retry-After": "1"},
632+
)
633+
normal_response = responses.post(
634+
url=multiscan_url,
635+
status=200,
636+
json=[
637+
{
638+
"policy_break_count": 0,
639+
"policies": ["pol"],
640+
"policy_breaks": [],
641+
}
642+
],
643+
)
644+
645+
result = client.multi_content_scan(
646+
[{"filename": FILENAME, "document": DOCUMENT}],
647+
)
648+
649+
assert rate_limit_response.call_count == 1
650+
assert normal_response.call_count == 1
651+
assert isinstance(result, MultiScanResult)
652+
callbacks.on_rate_limited.assert_called_once_with(1)
653+
654+
655+
@responses.activate
656+
def test_bogus_rate_limit():
657+
"""
658+
GIVEN a GGClient instance with callbacks
659+
WHEN calling an API endpoint and we hit a rate-limit
660+
AND we can't parse the rate-limit value
661+
THEN the client just returns the error
662+
AND the `on_rate_limited()` method of the callback is not called
663+
"""
664+
callbacks = Mock(spec=GGClientCallbacks)
665+
666+
client = create_client(callbacks=callbacks)
667+
multiscan_url = client._url_from_endpoint("multiscan", "v1")
668+
669+
rate_limit_response = responses.post(
670+
url=multiscan_url,
671+
status=429,
672+
headers={"Retry-After": "later"},
673+
)
674+
675+
result = client.multi_content_scan(
676+
[{"filename": FILENAME, "document": DOCUMENT}],
677+
)
678+
679+
assert rate_limit_response.call_count == 1
680+
assert isinstance(result, Detail)
681+
callbacks.on_rate_limited.assert_not_called()
682+
683+
615684
def test_quota_overview(client: GGClient):
616685
with my_vcr.use_cassette("quota.yaml"):
617686
quota_response = client.quota_overview()

0 commit comments

Comments
 (0)