Skip to content

Commit a094b2b

Browse files
SofienMJguer
authored andcommitted
feat(GGClient): add extra headers option
1 parent f1379ea commit a094b2b

File tree

2 files changed

+137
-18
lines changed

2 files changed

+137
-18
lines changed

pygitguardian/client.py

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class GGClient:
4949
base_uri: str
5050
timeout: Optional[float]
5151
user_agent: str
52+
extra_headers: Dict
5253

5354
def __init__(
5455
self,
@@ -92,41 +93,63 @@ def __init__(
9293
{
9394
"User-Agent": self.user_agent,
9495
"Authorization": "Token {0}".format(api_key),
95-
}
96+
},
9697
)
9798

9899
def request(
99100
self,
100101
method: str,
101102
endpoint: str,
102103
version: Optional[str] = DEFAULT_API_VERSION,
104+
extra_headers: Dict[str, str] = None,
103105
**kwargs
104106
) -> Response:
105107
if version:
106108
endpoint = urllib.parse.urljoin(version + "/", endpoint)
107109

108110
url = urllib.parse.urljoin(self.base_uri, endpoint)
109111

112+
headers = (
113+
{**self.session.headers, **extra_headers}
114+
if extra_headers
115+
else self.session.headers
116+
)
110117
return self.session.request(
111-
method=method, url=url, timeout=self.timeout, **kwargs
118+
method=method, url=url, timeout=self.timeout, headers=headers, **kwargs
119+
)
120+
121+
def get(
122+
self,
123+
endpoint: str,
124+
version: Optional[str] = DEFAULT_API_VERSION,
125+
extra_headers: Optional[Dict[str, str]] = None,
126+
**kwargs
127+
) -> Response:
128+
return self.request(
129+
method="get",
130+
endpoint=endpoint,
131+
version=version,
132+
extra_headers=extra_headers,
133+
**kwargs,
112134
)
113135

114136
def post(
115137
self,
116138
endpoint: str,
117139
data: str = None,
118140
version: str = DEFAULT_API_VERSION,
141+
extra_headers: Optional[Dict[str, str]] = None,
119142
**kwargs
120143
) -> Response:
121144
return self.request(
122-
"post", endpoint=endpoint, json=data, version=version, **kwargs
145+
"post",
146+
endpoint=endpoint,
147+
json=data,
148+
version=version,
149+
extra_headers=extra_headers,
150+
**kwargs,
123151
)
124152

125-
def get(
126-
self, endpoint: str, version: Optional[str] = DEFAULT_API_VERSION, **kwargs
127-
) -> Response:
128-
return self.request(method="get", endpoint=endpoint, version=version, **kwargs)
129-
130153
def health_check(self) -> Detail:
131154
"""
132155
health_check handles the /health endpoint of the API
@@ -144,13 +167,17 @@ def health_check(self) -> Detail:
144167
return obj
145168

146169
def content_scan(
147-
self, document: str, filename: Optional[str] = None
170+
self,
171+
document: str,
172+
filename: Optional[str] = None,
173+
extra_headers: Optional[Dict[str, str]] = None,
148174
) -> Union[Detail, ScanResult]:
149175
"""
150176
content_scan handles the /scan endpoint of the API
151177
152178
:param filename: name of file, example: "intro.py"
153179
:param document: content of file
180+
:param extra_headers: additional headers to add to the request
154181
:return: Detail or ScanResult response and status code
155182
"""
156183

@@ -160,7 +187,11 @@ def content_scan(
160187

161188
request_obj = Document.SCHEMA.load(doc_dict)
162189

163-
resp = self.post(endpoint="scan", data=request_obj)
190+
resp = self.post(
191+
endpoint="scan",
192+
data=request_obj,
193+
extra_headers=extra_headers,
194+
)
164195
if is_ok(resp):
165196
obj = ScanResult.SCHEMA.load(resp.json())
166197
else:
@@ -173,13 +204,15 @@ def content_scan(
173204
def multi_content_scan(
174205
self,
175206
documents: List[Dict[str, str]],
207+
extra_headers: Optional[Dict[str, str]] = None,
176208
) -> Union[Detail, MultiScanResult]:
177209
"""
178210
multi_content_scan handles the /multiscan endpoint of the API
179211
180212
:param documents: List of dictionaries containing the keys document
181213
and, optionally, filename.
182214
example: [{"document":"example content","filename":"intro.py"}]
215+
:param extra_headers: additional headers to add to the request
183216
:return: Detail or ScanResult response and status code
184217
"""
185218
if len(documents) > MULTI_DOCUMENT_LIMIT:
@@ -194,7 +227,11 @@ def multi_content_scan(
194227
else:
195228
raise TypeError("each document must be a dict")
196229

197-
resp = self.post(endpoint="multiscan", data=request_obj)
230+
resp = self.post(
231+
endpoint="multiscan",
232+
data=request_obj,
233+
extra_headers=extra_headers,
234+
)
198235

199236
if is_ok(resp):
200237
obj = MultiScanResult.SCHEMA.load(dict(scan_results=resp.json()))

tests/test_client.py

Lines changed: 89 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import json
2-
from typing import Dict, List, Type
3-
from unittest.mock import patch
2+
from typing import Dict, List, Optional, Type
3+
from unittest.mock import Mock, patch
44

55
import pytest
66
from marshmallow import ValidationError
7+
from requests.models import Response
78

89
from pygitguardian import GGClient
910
from pygitguardian.client import is_ok, load_detail
@@ -143,7 +144,7 @@
143144

144145

145146
@pytest.mark.parametrize(
146-
"api_key, uri, user_agent, timeout, exception",
147+
"api_key, uri, user_agent, timeout, exception ",
147148
[
148149
pytest.param(
149150
"validapi_keyforsure",
@@ -181,19 +182,37 @@
181182
ValueError,
182183
id="invalid prefix",
183184
),
185+
pytest.param(
186+
"validapi_keyforsure",
187+
"https://api.gitguardian.com/",
188+
"custom",
189+
30.0,
190+
None,
191+
id="Custom headers",
192+
),
184193
],
185194
)
186195
def test_client_creation(
187-
api_key: str, uri: str, user_agent: str, timeout: float, exception: Type[Exception]
196+
api_key: str,
197+
uri: str,
198+
user_agent: str,
199+
timeout: float,
200+
exception: Type[Exception],
188201
):
189202
if exception is not None:
190203
with pytest.raises(exception):
191204
client = GGClient(
192-
api_key=api_key, base_uri=uri, user_agent=user_agent, timeout=timeout
205+
api_key=api_key,
206+
base_uri=uri,
207+
user_agent=user_agent,
208+
timeout=timeout,
193209
)
194210
else:
195211
client = GGClient(
196-
base_uri=uri, api_key=api_key, user_agent=user_agent, timeout=timeout
212+
base_uri=uri,
213+
api_key=api_key,
214+
user_agent=user_agent,
215+
timeout=timeout,
197216
)
198217

199218
if exception is None:
@@ -202,8 +221,8 @@ def test_client_creation(
202221
else:
203222
assert client.base_uri == DEFAULT_BASE_URI
204223
assert client.api_key == api_key
205-
assert user_agent in client.session.headers["User-Agent"]
206224
assert client.timeout == timeout
225+
assert user_agent in client.session.headers["User-Agent"]
207226
assert client.session.headers["Authorization"] == "Token {0}".format(api_key)
208227

209228

@@ -416,3 +435,66 @@ def test_assert_content_type(client: GGClient):
416435
assert obj.status_code == 200
417436
assert isinstance(obj, Detail)
418437
assert str(obj).startswith("200:"), str(obj)
438+
439+
440+
@pytest.mark.parametrize(
441+
"session_headers, extra_headers, expected_headers",
442+
[
443+
pytest.param(
444+
{"session-header": "value"},
445+
None,
446+
{"session-header": "value"},
447+
id="no-additional-headers",
448+
),
449+
pytest.param(
450+
{"session-header": "value"},
451+
{"additional-header": "value"},
452+
{"session-header": "value", "additional-header": "value"},
453+
id="additional-headers",
454+
),
455+
pytest.param(
456+
{"session-header": "value", "common-header": "session-value"},
457+
{"additional-header": "value", "common-header": "add-value"},
458+
{
459+
"session-header": "value",
460+
"additional-header": "value",
461+
"common-header": "add-value",
462+
},
463+
id="priority-additional-headers",
464+
),
465+
],
466+
)
467+
@patch("requests.Session.request")
468+
@my_vcr.use_cassette
469+
def test_extra_headers(
470+
request_mock: Mock,
471+
client: GGClient,
472+
session_headers: Dict[str, str],
473+
extra_headers: Optional[Dict[str, str]],
474+
expected_headers: Dict[str, str],
475+
):
476+
"""
477+
GIVEN client's session headers
478+
WHEN calling any client method with additional headers
479+
THEN session/method headers should be merged with priority on method headers
480+
"""
481+
client.session.headers = session_headers
482+
483+
mock_response = Mock(spec=Response)
484+
mock_response.headers = {"content-type": "text"}
485+
mock_response.text = "some error"
486+
mock_response.status_code = 400
487+
request_mock.return_value = mock_response
488+
489+
client.multi_content_scan(
490+
[{"filename": FILENAME, "document": DOCUMENT}],
491+
extra_headers=extra_headers,
492+
)
493+
assert request_mock.called
494+
_, kwargs = request_mock.call_args
495+
assert expected_headers == kwargs["headers"]
496+
497+
client.content_scan("some_string", extra_headers=extra_headers)
498+
assert request_mock.called
499+
_, kwargs = request_mock.call_args
500+
assert expected_headers == kwargs["headers"]

0 commit comments

Comments
 (0)