Skip to content

Commit c180fea

Browse files
authored
Merge pull request #61 from GitGuardian/agateau/pyright
Switch to pyright, complete type-hinting
2 parents 9020a7f + c218cf4 commit c180fea

File tree

10 files changed

+103
-79
lines changed

10 files changed

+103
-79
lines changed

.github/workflows/test-lint.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ jobs:
2525
run: |
2626
python -m pip install --upgrade pip
2727
python -m pip install --upgrade pipenv==2022.10.4 pre-commit
28-
pipenv install --system --skip-lock
28+
pipenv install --dev --skip-lock
2929
3030
- uses: actions/cache@v3
3131
with:

.pre-commit-config.yaml

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,17 @@ repos:
1212
hooks:
1313
- id: flake8
1414

15-
- repo: https://github.com/pre-commit/mirrors-mypy
16-
rev: v0.961
17-
hooks:
18-
- id: mypy
19-
additional_dependencies: [types-requests]
15+
# use a "local" repo and not the pyright hook to ensure pyright runs in the same virtualenv
16+
# as the rest of the code
17+
- repo: local
18+
hooks:
19+
- id: pyright
20+
name: pyright
21+
entry: 'pipenv run pyright'
22+
language: system
23+
types: [python]
24+
# do not pass filenames, otherwise Pyright might scan files we don't want it to scan
25+
pass_filenames: false
2026

2127
- repo: https://github.com/pre-commit/pre-commit-hooks
2228
rev: v4.3.0

MANIFEST.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
include README.md
22
include LICENSE
3+
include pygitguardian/py.typed

Pipfile

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ pre-commit = "*"
1616
pytest = "*"
1717
vcrpy = ">=4.3.0,!=4.3.1,<4.4.0" # v4.3.1 broke decode_compressed_response
1818
urllib3 = "<2" # pin until https://github.com/kevin1024/vcrpy/issues/688 is fixed
19-
mypy = "==0.961"
20-
types-requests = "*"
2119
scriv = { version = "*", extras = ["toml"] }
2220
responses = ">=0.23.1,<0.24.0"
21+
pyright = "==1.1.313"

pygitguardian/client.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,11 @@
1212
from requests import Response, Session, codes
1313

1414
from .config import DEFAULT_API_VERSION, DEFAULT_BASE_URI, DEFAULT_TIMEOUT
15-
from .iac_models import (
16-
IaCScanParameters,
17-
IaCScanParametersSchema,
18-
IaCScanResult,
19-
IaCScanResultSchema,
20-
)
15+
from .iac_models import IaCScanParameters, IaCScanParametersSchema, IaCScanResult
2116
from .models import (
2217
Detail,
2318
Document,
19+
DocumentSchema,
2420
HealthCheckResponse,
2521
HoneytokenResponse,
2622
MultiScanResult,
@@ -69,7 +65,7 @@ def load_detail(resp: Response) -> Detail:
6965
else:
7066
data = {"detail": resp.text}
7167

72-
return cast(Detail, Detail.SCHEMA.load(data))
68+
return Detail.from_dict(data)
7369

7470

7571
def is_ok(resp: Response) -> bool:
@@ -266,7 +262,7 @@ def get(
266262
def post(
267263
self,
268264
endpoint: str,
269-
data: Optional[Dict[str, Any]] = None,
265+
data: Union[Dict[str, Any], List[Dict[str, Any]], None] = None,
270266
version: str = DEFAULT_API_VERSION,
271267
extra_headers: Optional[Dict[str, str]] = None,
272268
**kwargs: Any,
@@ -320,8 +316,8 @@ def content_scan(
320316
if filename:
321317
doc_dict["filename"] = filename
322318

323-
request_obj = Document.SCHEMA.load(doc_dict)
324-
Document.SCHEMA.validate_size(
319+
request_obj = cast(Dict[str, Any], Document.SCHEMA.load(doc_dict))
320+
DocumentSchema.validate_size(
325321
request_obj, self.secret_scan_preferences.maximum_document_size
326322
)
327323

@@ -333,7 +329,7 @@ def content_scan(
333329

334330
obj: Union[Detail, ScanResult]
335331
if is_ok(resp):
336-
obj = ScanResult.SCHEMA.load(resp.json())
332+
obj = ScanResult.from_dict(resp.json())
337333
else:
338334
obj = load_detail(resp)
339335

@@ -367,12 +363,14 @@ def multi_content_scan(
367363
)
368364

369365
if all(isinstance(doc, dict) for doc in documents):
370-
request_obj = Document.SCHEMA.load(documents, many=True)
366+
request_obj = cast(
367+
List[Dict[str, Any]], Document.SCHEMA.load(documents, many=True)
368+
)
371369
else:
372370
raise TypeError("each document must be a dict")
373371

374372
for document in request_obj:
375-
Document.SCHEMA.validate_size(
373+
DocumentSchema.validate_size(
376374
document, self.secret_scan_preferences.maximum_document_size
377375
)
378376

@@ -390,7 +388,7 @@ def multi_content_scan(
390388

391389
obj: Union[Detail, MultiScanResult]
392390
if is_ok(resp):
393-
obj = MultiScanResult.SCHEMA.load(dict(scan_results=resp.json()))
391+
obj = MultiScanResult.from_dict({"scan_results": resp.json()})
394392
else:
395393
obj = load_detail(resp)
396394

@@ -416,7 +414,7 @@ def quota_overview(
416414

417415
obj: Union[Detail, QuotaResponse]
418416
if is_ok(resp):
419-
obj = QuotaResponse.SCHEMA.load(resp.json())
417+
obj = QuotaResponse.from_dict(resp.json())
420418
else:
421419
obj = load_detail(resp)
422420

@@ -455,7 +453,7 @@ def create_honeytoken(
455453
result.status_code = 504
456454
else:
457455
if is_create_ok(resp):
458-
result = HoneytokenResponse.SCHEMA.load(resp.json())
456+
result = HoneytokenResponse.from_dict(resp.json())
459457
else:
460458
result = load_detail(resp)
461459
result.status_code = resp.status_code
@@ -487,7 +485,7 @@ def iac_directory_scan(
487485
result.status_code = 504
488486
else:
489487
if is_ok(resp):
490-
result = IaCScanResultSchema().load(resp.json())
488+
result = IaCScanResult.from_dict(resp.json())
491489
else:
492490
result = load_detail(resp)
493491

@@ -510,7 +508,7 @@ def read_metadata(self) -> Optional[Detail]:
510508
result = load_detail(resp)
511509
result.status_code = resp.status_code
512510
return result
513-
metadata = ServerMetadata.SCHEMA.load(resp.json())
511+
metadata = ServerMetadata.from_dict(resp.json())
514512

515513
self.secret_scan_preferences = metadata.secret_scan_preferences
516514
return None

pygitguardian/iac_models.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
from dataclasses import dataclass, field
2-
from typing import List, Optional
2+
from typing import List, Optional, Type, cast
33

44
import marshmallow_dataclass
55

6-
from pygitguardian.models import Base, BaseSchema
6+
from pygitguardian.models import Base, BaseSchema, FromDictMixin
77

88

99
@dataclass
10-
class IaCVulnerability(Base):
10+
class IaCVulnerability(Base, FromDictMixin):
1111
policy: str
1212
policy_id: str
1313
line_end: int
@@ -18,37 +18,48 @@ class IaCVulnerability(Base):
1818
severity: str = ""
1919

2020

21-
IaCVulnerabilitySchema = marshmallow_dataclass.class_schema(
22-
IaCVulnerability, BaseSchema
21+
IaCVulnerabilitySchema = cast(
22+
Type[BaseSchema], marshmallow_dataclass.class_schema(IaCVulnerability, BaseSchema)
2323
)
2424

25+
IaCVulnerability.SCHEMA = IaCVulnerabilitySchema()
26+
2527

2628
@dataclass
27-
class IaCFileResult(Base):
29+
class IaCFileResult(Base, FromDictMixin):
2830
filename: str
2931
incidents: List[IaCVulnerability]
3032

3133

32-
IaCFileResultSchema = marshmallow_dataclass.class_schema(IaCFileResult, BaseSchema)
34+
IaCFileResultSchema = cast(
35+
Type[BaseSchema], marshmallow_dataclass.class_schema(IaCFileResult, BaseSchema)
36+
)
37+
38+
IaCFileResult.SCHEMA = IaCFileResultSchema()
3339

3440

3541
@dataclass
36-
class IaCScanParameters(Base):
42+
class IaCScanParameters(Base, FromDictMixin):
3743
ignored_policies: List[str] = field(default_factory=list)
3844
minimum_severity: Optional[str] = None
3945

4046

41-
IaCScanParametersSchema = marshmallow_dataclass.class_schema(
42-
IaCScanParameters, BaseSchema
47+
IaCScanParametersSchema = cast(
48+
Type[BaseSchema], marshmallow_dataclass.class_schema(IaCScanParameters, BaseSchema)
4349
)
4450

51+
IaCScanParameters.SCHEMA = IaCScanParametersSchema()
52+
4553

4654
@dataclass
47-
class IaCScanResult(Base):
55+
class IaCScanResult(Base, FromDictMixin):
4856
id: str = ""
4957
type: str = ""
5058
iac_engine_version: str = ""
5159
entities_with_incidents: List[IaCFileResult] = field(default_factory=list)
5260

5361

54-
IaCScanResultSchema = marshmallow_dataclass.class_schema(IaCScanResult, BaseSchema)
62+
IaCScanResultSchema = cast(
63+
Type[BaseSchema], marshmallow_dataclass.class_schema(IaCScanResult, BaseSchema)
64+
)
65+
IaCScanResult.SCHEMA = IaCScanResultSchema()

pygitguardian/models.py

Lines changed: 45 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,48 @@
1313
pre_load,
1414
validate,
1515
)
16+
from typing_extensions import Self
1617

1718
from .config import DOCUMENT_SIZE_THRESHOLD_BYTES, MULTI_DOCUMENT_LIMIT
1819

1920

21+
class ToDictMixin:
22+
"""
23+
Provides a type-safe `to_dict()` method for classes using Marshmallow
24+
"""
25+
26+
SCHEMA: ClassVar[Schema]
27+
28+
def to_dict(self) -> Dict[str, Any]:
29+
return cast(Dict[str, Any], self.SCHEMA.dump(self))
30+
31+
32+
class FromDictMixin:
33+
"""This class must be used as an additional base class for all classes whose schema
34+
implements a `post_load` function turning the received dict into a class instance.
35+
36+
It makes it possible to deserialize an object using `MyClass.from_dict(dct)` instead
37+
of `MyClass.SCHEMA.load(dct)`. The `from_dict()` method is shorter, but more
38+
importantly, type-safe: its return type is an instance of `MyClass`, not
39+
`list[Any] | Any`.
40+
41+
Reference: https://marshmallow.readthedocs.io/en/stable/quickstart.html#deserializing-to-objects E501
42+
"""
43+
44+
SCHEMA: ClassVar[Schema]
45+
46+
@classmethod
47+
def from_dict(cls, dct: Dict[str, Any]) -> Self:
48+
return cast(Self, cls.SCHEMA.load(dct))
49+
50+
2051
class BaseSchema(Schema):
2152
class Meta:
2253
ordered = True
2354
unknown = EXCLUDE
2455

2556

26-
class Base:
27-
SCHEMA: ClassVar[BaseSchema]
28-
57+
class Base(ToDictMixin):
2958
def __init__(self, status_code: Optional[int] = None) -> None:
3059
self.status_code = status_code
3160

@@ -35,12 +64,6 @@ def to_json(self) -> str:
3564
"""
3665
return cast(str, self.SCHEMA.dumps(self))
3766

38-
def to_dict(self) -> Dict:
39-
"""
40-
to_dict converts model to a dictionary representation.
41-
"""
42-
return cast(Dict, self.SCHEMA.dump(self))
43-
4467
@property
4568
def success(self) -> bool:
4669
return self.__bool__()
@@ -122,7 +145,7 @@ def make_detail_response(self, data: Dict[str, Any], **kwargs: Any) -> "Detail":
122145
return Detail(**data)
123146

124147

125-
class Detail(Base):
148+
class Detail(Base, FromDictMixin):
126149
"""Detail is a response object mostly returned on error or when the
127150
api output is a simple string.
128151
@@ -155,7 +178,7 @@ def make_match(self, data: Dict[str, Any], **kwargs: Any) -> "Match":
155178
return Match(**data)
156179

157180

158-
class Match(Base):
181+
class Match(Base, FromDictMixin):
159182
"""
160183
Match describes a found issue by GitGuardian.
161184
With info such as match location and type.
@@ -219,7 +242,7 @@ def make_policy_break(self, data: Dict[str, Any], **kwargs: Any) -> "PolicyBreak
219242
return PolicyBreak(**data)
220243

221244

222-
class PolicyBreak(Base):
245+
class PolicyBreak(Base, FromDictMixin):
223246
"""
224247
PolicyBreak describes a GitGuardian policy break found
225248
in a scan.
@@ -269,7 +292,7 @@ def make_scan_result(self, data: Dict[str, Any], **kwargs: Any) -> "ScanResult":
269292
return ScanResult(**data)
270293

271294

272-
class ScanResult(Base):
295+
class ScanResult(Base, FromDictMixin):
273296
"""ScanResult is a response object returned on a Content Scan
274297
275298
Attributes:
@@ -355,7 +378,7 @@ def make_scan_result(
355378
return MultiScanResult(**data)
356379

357380

358-
class MultiScanResult(Base):
381+
class MultiScanResult(Base, FromDictMixin):
359382
"""ScanResult is a response object returned on a Content Scan
360383
361384
Attributes:
@@ -425,7 +448,7 @@ def make_quota(self, data: Dict[str, Any], **kwargs: Any) -> "Quota":
425448
return Quota(**data)
426449

427450

428-
class Quota(Base):
451+
class Quota(Base, FromDictMixin):
429452
"""
430453
Quota describes a quota category in the GitGuardian API.
431454
Allows you to check your current available quota.
@@ -468,7 +491,7 @@ def make_quota_response(
468491
return QuotaResponse(**data)
469492

470493

471-
class QuotaResponse(Base):
494+
class QuotaResponse(Base, FromDictMixin):
472495
"""
473496
Quota describes a quota category in the GitGuardian API.
474497
Allows you to check your current available quota.
@@ -515,7 +538,7 @@ def make_honeytoken_response(
515538
return HoneytokenResponse(**data)
516539

517540

518-
class HoneytokenResponse(Base):
541+
class HoneytokenResponse(Base, FromDictMixin):
519542
"""
520543
honeytoken creation in the GitGuardian API.
521544
Allows users to create and get a honeytoken.
@@ -633,14 +656,15 @@ class SecretScanPreferences:
633656

634657

635658
@dataclass
636-
class ServerMetadata(Base):
659+
class ServerMetadata(Base, FromDictMixin):
637660
version: str
638661
preferences: Dict[str, Any]
639662
secret_scan_preferences: SecretScanPreferences = field(
640663
default_factory=SecretScanPreferences
641664
)
642665

643666

644-
ServerMetadata.SCHEMA = marshmallow_dataclass.class_schema(
645-
ServerMetadata, base_schema=BaseSchema
646-
)()
667+
ServerMetadata.SCHEMA = cast(
668+
BaseSchema,
669+
marshmallow_dataclass.class_schema(ServerMetadata, base_schema=BaseSchema)(),
670+
)

pygitguardian/py.typed

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# PEP-561 Support File.
2+
# "Package maintainers who wish to support type checking of their code MUST add a marker file named py.typed to their package supporting typing".

0 commit comments

Comments
 (0)