Skip to content

Commit 7161554

Browse files
jguerreiroJguer
authored andcommitted
fix(pygitguardian): accept non-json content and load as detail
1 parent 170c129 commit 7161554

File tree

3 files changed

+60
-24
lines changed

3 files changed

+60
-24
lines changed

pygitguardian/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from .client import GGClient
33

44

5-
__version__ = "1.1.1"
5+
__version__ = "1.1.2"
66
GGClient._version = __version__
77

88
__all__ = [

pygitguardian/client.py

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import urllib.parse
33
from typing import Dict, List, Optional, Union
44

5-
import requests
65
from requests import Response, Session, codes
76

87
from .config import (
@@ -14,9 +13,38 @@
1413
from .models import Detail, Document, MultiScanResult, ScanResult
1514

1615

16+
def load_detail(resp: Response) -> Detail:
17+
"""
18+
load_detail loads a Detail from a response
19+
be it JSON or html.
20+
21+
:param resp: API response
22+
:type resp: Response
23+
:return: detail object of response
24+
:rtype: Detail
25+
"""
26+
if resp.headers["content-type"] == "application/json":
27+
data = resp.json()
28+
else:
29+
data = {"detail": resp.text}
30+
31+
return Detail.SCHEMA.load(data)
32+
33+
34+
def is_ok(resp: Response) -> bool:
35+
"""
36+
is_ok returns True is the API responded with 200
37+
and the content type is JSON.
38+
"""
39+
return (
40+
resp.headers["content-type"] == "application/json"
41+
and resp.status_code == codes.ok
42+
)
43+
44+
1745
class GGClient:
1846
_version = "undefined"
19-
session: requests.Session
47+
session: Session
2048
api_key: str
2149
base_uri: str
2250
timeout: Optional[float]
@@ -26,7 +54,7 @@ def __init__(
2654
self,
2755
api_key: str,
2856
base_uri: Optional[str] = None,
29-
session: Optional[requests.Session] = None,
57+
session: Optional[Session] = None,
3058
user_agent: Optional[str] = None,
3159
timeout: Optional[float] = DEFAULT_TIMEOUT,
3260
):
@@ -51,7 +79,7 @@ def __init__(
5179

5280
self.base_uri = base_uri
5381
self.api_key = api_key
54-
self.session = session if isinstance(session, Session) else requests.Session()
82+
self.session = session if isinstance(session, Session) else Session()
5583
self.timeout = timeout
5684
self.user_agent = "pygitguardian/{0} ({1};py{2})".format(
5785
self._version, platform.system(), platform.python_version()
@@ -79,15 +107,10 @@ def request(
79107

80108
url = urllib.parse.urljoin(self.base_uri, endpoint)
81109

82-
resp = self.session.request(
110+
return self.session.request(
83111
method=method, url=url, timeout=self.timeout, **kwargs
84112
)
85113

86-
if resp.headers["content-type"] != "application/json":
87-
raise TypeError("Response is not JSON")
88-
89-
return resp
90-
91114
def post(
92115
self,
93116
endpoint: str,
@@ -115,7 +138,7 @@ def health_check(self) -> Detail:
115138
"""
116139
resp = self.get(endpoint="health")
117140

118-
obj = Detail.SCHEMA.load(resp.json())
141+
obj = load_detail(resp)
119142
obj.status_code = resp.status_code
120143

121144
return obj
@@ -138,10 +161,10 @@ def content_scan(
138161
request_obj = Document.SCHEMA.load(doc_dict)
139162

140163
resp = self.post(endpoint="scan", data=request_obj)
141-
if resp.status_code == codes.ok:
164+
if is_ok(resp):
142165
obj = ScanResult.SCHEMA.load(resp.json())
143166
else:
144-
obj = Detail.SCHEMA.load(resp.json())
167+
obj = load_detail(resp)
145168

146169
obj.status_code = resp.status_code
147170

@@ -173,10 +196,10 @@ def multi_content_scan(
173196

174197
resp = self.post(endpoint="multiscan", data=request_obj)
175198

176-
if resp.status_code == codes.ok:
199+
if is_ok(resp):
177200
obj = MultiScanResult.SCHEMA.load(dict(scan_results=resp.json()))
178201
else:
179-
obj = Detail.SCHEMA.load(resp.json())
202+
obj = load_detail(resp)
180203

181204
obj.status_code = resp.status_code
182205

tests/test_client.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import json
2-
from typing import Dict, List
2+
from typing import Dict, List, Type
33
from unittest.mock import patch
44

55
import pytest
66
from marshmallow import ValidationError
77

88
from pygitguardian import GGClient
9+
from pygitguardian.client import is_ok, load_detail
910
from pygitguardian.config import (
1011
DEFAULT_BASE_URI,
1112
DOCUMENT_SIZE_THRESHOLD_BYTES,
@@ -183,7 +184,7 @@
183184
],
184185
)
185186
def test_client_creation(
186-
api_key: str, uri: str, user_agent: str, timeout: float, exception: Exception
187+
api_key: str, uri: str, user_agent: str, timeout: float, exception: Type[Exception]
187188
):
188189
if exception is not None:
189190
with pytest.raises(exception):
@@ -299,7 +300,7 @@ def test_multi_content_scan(
299300
],
300301
)
301302
def test_content_scan_exceptions(
302-
client: GGClient, to_scan: str, exception: Exception, regex: str
303+
client: GGClient, to_scan: str, exception: Type[Exception], regex: str
303304
):
304305
with pytest.raises(exception, match=regex):
305306
client.content_scan(to_scan)
@@ -313,7 +314,7 @@ def test_content_scan_exceptions(
313314
],
314315
)
315316
def test_multi_content_exceptions(
316-
client: GGClient, to_scan: List, exception: Exception
317+
client: GGClient, to_scan: List, exception: Type[Exception]
317318
):
318319
with pytest.raises(exception):
319320
client.multi_content_scan(to_scan)
@@ -326,7 +327,7 @@ def test_multi_content_not_ok():
326327

327328
obj = client.multi_content_scan(req)
328329

329-
assert obj.status_code, 401
330+
assert obj.status_code == 401
330331
assert isinstance(obj, Detail)
331332
assert obj.detail == "Invalid API key."
332333

@@ -338,7 +339,7 @@ def test_content_not_ok():
338339

339340
obj = client.content_scan(**req)
340341

341-
assert obj.status_code, 401
342+
assert obj.status_code == 401
342343
assert isinstance(obj, Detail)
343344
assert obj.detail == "Invalid API key."
344345

@@ -401,5 +402,17 @@ def test_content_scan(
401402

402403
@my_vcr.use_cassette
403404
def test_assert_content_type(client: GGClient):
404-
with pytest.raises(TypeError):
405-
client.get(endpoint="/docs/static/logo.png", version=None)
405+
"""
406+
GIVEN a response that's 200 but the content is not JSON
407+
WHEN is_ok is called
408+
THEN is_ok should be false
409+
WHEN load_detail is called
410+
THEN is should return a Detail object
411+
"""
412+
resp = client.get(endpoint="/docs/static/logo.png", version=None)
413+
assert is_ok(resp) is False
414+
obj = load_detail(resp)
415+
obj.status_code = resp.status_code
416+
assert obj.status_code == 200
417+
assert isinstance(obj, Detail)
418+
assert str(obj).startswith("200:"), str(obj)

0 commit comments

Comments
 (0)