Skip to content

Commit cc7b151

Browse files
author
João Guerreiro
committed
feat(pygitguardian): add multiscan
1 parent 6a98962 commit cc7b151

13 files changed

+769
-190
lines changed

.github/workflows/test-lint.yml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ jobs:
3030
runs-on: ubuntu-latest
3131
strategy:
3232
matrix:
33-
python-version: [3.5, 3.6, 3.7, 3.8]
33+
python-version: [3.6, 3.7, 3.8]
3434

3535
steps:
3636
- uses: actions/checkout@v2
@@ -42,8 +42,7 @@ jobs:
4242
run: |
4343
python -m pip install --upgrade pip
4444
python -m pip install --upgrade pipenv
45-
pipenv install --system --skip-lock
46-
pip install nose coverage
45+
pipenv install --system --dev --skip-lock
4746
- name: Test with pytest
4847
run: |
4948
pipenv run coverage run --source pygitguardian -m nose tests

Pipfile

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,21 @@ verify_ssl = true
44
name = "pypi"
55

66
[packages]
7-
pygitguardian = {editable = true,path = "."}
87
marshmallow = "~=3.5"
8+
pygitguardian = {editable = true,path = "."}
99
requests = ">=2.21.0"
1010

1111
[dev-packages]
1212
black = "==19.10b0"
13-
flake8 = "*"
14-
ipython = "*"
15-
pre-commit = "*"
1613
coverage = "*"
14+
flake8 = "*"
1715
flake8-isort = "*"
16+
ipython = "*"
1817
isort = "*"
1918
nose = "*"
19+
pre-commit = "*"
20+
vcrpy = "*"
21+
vcrpy-unittest = "*"
22+
23+
[pipenv]
24+
allow_prereleases = true

examples/content_scan.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import sys
33
import traceback
44

5+
from requests import codes
6+
57
from pygitguardian import GGClient
68

79

@@ -17,9 +19,9 @@
1719
client = GGClient(token=API_KEY)
1820

1921
# Check the health of the API and the token used.
20-
health_obj = client.health_check()
22+
health_obj, status = client.health_check()
2123

22-
if health_obj.success:
24+
if status == codes[r"\o/"]: # this is 200 but cooler
2325
try:
2426
scan_result = client.content_scan(filename=FILENAME, document=DOCUMENT)
2527
except Exception as exc:

pygitguardian/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from .schemas import DetailSchema, DocumentSchema, ScanResultSchema
55

66

7-
__version__ = "1.0.0"
7+
__version__ = "1.0.1"
8+
GGClient._version = __version__
89

910
__all__ = [
1011
"Detail",

pygitguardian/client.py

Lines changed: 63 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
import platform
12
import urllib.parse
2-
from typing import Any, Optional, Tuple, Union
3+
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
34

45
import requests
56
from marshmallow import Schema
@@ -18,26 +19,22 @@ class GGClient:
1819
DETAIL_SCHEMA = DetailSchema()
1920
DOCUMENT_SCHEMA = DocumentSchema()
2021
SCAN_RESULT_SCHEMA = ScanResultSchema()
22+
_version = "undefined"
2123

2224
def __init__(
2325
self,
2426
token: str,
25-
base_uri: str = None,
26-
session: requests.Session = None,
27-
user_agent: str = "",
28-
timeout: float = _DEFAULT_TIMEOUT,
29-
):
27+
base_uri: Optional[str] = None,
28+
session: Optional[requests.Session] = None,
29+
user_agent: Optional[str] = None,
30+
timeout: Optional[float] = _DEFAULT_TIMEOUT,
31+
) -> "GGClient":
3032
"""
3133
:param token: APIKey to be added to requests
32-
:type token: str
3334
:param base_uri: Base URI for the API, defaults to "https://api.gitguardian.com"
34-
:type base_uri: str, optional
3535
:param session: custom requests session, defaults to requests.Session()
36-
:type session: requests.Session, optional
3736
:param user_agent: extra user agent string appended to the default one, defaults to None
38-
:type user_agent: str, optional
3937
:param timeout: request timeout, defaults to 20s
40-
:type timeout: float, optional
4138
4239
:raises ValueError: if the protocol is invalid
4340
"""
@@ -57,12 +54,15 @@ def __init__(
5754
session if session is isinstance(session, Session) else requests.Session()
5855
)
5956
self.timeout = timeout
57+
self.user_agent = "pygitguardian/{0} ({1};py{2})".format(
58+
self._version, platform.system(), platform.python_version()
59+
)
60+
61+
if user_agent:
62+
self.user_agent = " ".join([self.user_agent, user_agent])
6063

6164
self.session.headers.update(
62-
{
63-
"User-Agent": " ".join(["pygitguardian", user_agent]),
64-
"Authorization": "Token {0}".format(token),
65-
}
65+
{"User-Agent": self.user_agent, "Authorization": "Token {0}".format(token)}
6666
)
6767

6868
def request(
@@ -71,6 +71,7 @@ def request(
7171
endpoint: str,
7272
schema: Schema = None,
7373
version: str = _API_VERSION,
74+
many: bool = False,
7475
**kwargs
7576
) -> Tuple[Any, Response]:
7677
if version:
@@ -86,11 +87,15 @@ def request(
8687
raise TypeError("Response is not JSON")
8788

8889
if response.status_code == codes.ok and schema:
89-
obj = schema.load(response.json())
90+
obj = schema.load(response.json(), many=many)
91+
if many:
92+
for element in obj:
93+
element.status_code = response.status_code
94+
else:
95+
obj.status_code = response.status_code
9096
else:
9197
obj = self.DETAIL_SCHEMA.load(response.json())
92-
93-
obj.status_code = response.status_code
98+
obj.status_code = response.status_code
9499

95100
return obj, response
96101

@@ -100,6 +105,7 @@ def post(
100105
data: str = None,
101106
schema: Schema = None,
102107
version: str = _API_VERSION,
108+
many: bool = False,
103109
**kwargs
104110
) -> Tuple[Any, Response]:
105111
return self.request(
@@ -108,6 +114,7 @@ def post(
108114
schema=schema,
109115
json=data,
110116
version=version,
117+
many=many,
111118
**kwargs,
112119
)
113120

@@ -116,6 +123,7 @@ def get(
116123
endpoint: str,
117124
schema: Schema = None,
118125
version: str = _API_VERSION,
126+
many: bool = False,
119127
**kwargs
120128
) -> Tuple[Any, Response]:
121129
return self.request(
@@ -124,37 +132,58 @@ def get(
124132

125133
def content_scan(
126134
self, document: str, filename: Optional[str] = None
127-
) -> Union[Detail, ScanResult]:
128-
"""content_scan handles the /scan endpoint of the API
129-
130-
use filename=dummy to avoid evaluation of filename and file extension policies
135+
) -> Tuple[Union[Detail, ScanResult], int]:
136+
"""
137+
content_scan handles the /scan endpoint of the API
131138
132139
:param filename: name of file, example: "intro.py"
133-
:type filename: str
134140
:param document: content of file
135-
:type document: str
136-
:return: Detail or ScanResult response
137-
:rtype: Union[Detail, ScanResult]
141+
:return: Detail or ScanResult response and status code
138142
"""
139143

140144
doc_dict = {"document": document}
141145
if filename:
142146
doc_dict["filename"] = filename
143147

144148
request_obj = self.DOCUMENT_SCHEMA.load(doc_dict)
145-
obj, _ = self.post(
149+
obj, resp = self.post(
146150
endpoint="scan", data=request_obj, schema=self.SCAN_RESULT_SCHEMA
147151
)
148-
return obj
152+
return obj, resp.status_code
153+
154+
def multi_content_scan(
155+
self, documents: Iterable[Dict[str, str]],
156+
) -> Tuple[Union[Detail, List[ScanResult]], int]:
157+
"""
158+
multi_content_scan handles the /multiscan endpoint of the API
159+
160+
:param documents: List of dictionaries containing the keys document
161+
and, optionally, filename.
162+
example: [{"document":"example content","filename":"intro.py"}]
163+
:return: Detail or list of ScanResult responses and status code
164+
"""
165+
166+
if all(isinstance(doc, dict) for doc in documents):
167+
request_obj = self.DOCUMENT_SCHEMA.load(documents, many=True)
168+
else:
169+
raise TypeError("documents must be a dict")
170+
171+
obj, resp = self.post(
172+
endpoint="multiscan",
173+
data=request_obj,
174+
schema=self.SCAN_RESULT_SCHEMA,
175+
many=True,
176+
)
177+
return obj, resp.status_code
149178

150-
def health_check(self) -> Detail:
151-
"""health_check handles the /health endpoint of the API
179+
def health_check(self) -> Tuple[Detail, int]:
180+
"""
181+
health_check handles the /health endpoint of the API
152182
153183
use Detail.status_code to check the response status code of the API
154184
155185
200 if server is online and token is valid
156-
:return: Detail response,
157-
:rtype: Detail
186+
:return: Detail response and status code
158187
"""
159-
obj, _ = self.get(endpoint="health")
160-
return obj
188+
obj, resp = self.get(endpoint="health")
189+
return obj, resp.status_code

pygitguardian/models.py

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,7 @@
11
from typing import List, Optional
22

33

4-
class BaseObject:
5-
UNSET_STATUS_CODE = 600
6-
7-
def __init__(self):
8-
self.status_code = 600
9-
10-
@property
11-
def success(self):
12-
"""success returns True if call returned 200
13-
14-
:return: call status
15-
:rtype: bool
16-
"""
17-
return self.status_code == 200
18-
19-
20-
class Detail(BaseObject):
4+
class Detail:
215
"""Detail is a response object mostly returned on error or when the
226
api output is a simple string
237
@@ -87,7 +71,7 @@ def __repr__(self):
8771
)
8872

8973

90-
class ScanResult(BaseObject):
74+
class ScanResult:
9175
"""ScanResult is a response object returned on a Content Scan
9276
9377
Attributes:

pygitguardian/schemas.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,44 @@
33
This module contains marshmallow schemas responsible for
44
serializing/deserializing request and response objects
55
"""
6-
from marshmallow import Schema, fields, post_load, validate
6+
from marshmallow import (
7+
EXCLUDE,
8+
Schema,
9+
ValidationError,
10+
fields,
11+
post_load,
12+
validate,
13+
validates,
14+
)
715

816
from .models import Detail, Match, PolicyBreak, ScanResult
917

1018

19+
MB = 1048576
20+
21+
1122
class DocumentSchema(Schema):
12-
filename = fields.Str(validate=validate.Length(max=256))
13-
document = fields.Str(validate=validate.Length(max=1000000), required=True)
23+
class Meta:
24+
unknown = EXCLUDE
25+
26+
filename = fields.String(validate=validate.Length(max=256))
27+
document = fields.String(required=True)
28+
29+
@validates("document")
30+
def validate_document(self, document: str) -> str:
31+
"""
32+
validate that document does not exceed the scan size limit and contains no null characters
33+
"""
34+
encoded = document.encode("utf-8")
35+
if len(encoded) > MB:
36+
raise ValidationError(
37+
"This file exceeds the maximum allowed size of {}B".format(MB)
38+
)
39+
40+
if "\x00" in document:
41+
raise ValidationError("Document has null characters")
42+
43+
return document
1444

1545

1646
class MatchSchema(Schema):

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ max-line-length = 120
77
line_length=88
88
lines_after_imports=2
99
multi_line_output=3
10-
known_third_party=marshmallow,requests,setuptools
10+
known_third_party=marshmallow,requests,setuptools,vcr_unittest
1111
include_trailing_comma=true
1212

1313
[metadata]

0 commit comments

Comments
 (0)