Skip to content

Commit a33136a

Browse files
committed
chore: introduce helper mixins
Introduce FromDictMixin to provide a more type-safe way of deserializing a dict into a class instance, and its reverse ToDictMixin. Make model classes inherit the appropriate mixins. Use the methods provided by the mixins instead of the not-type-safe Marshmallow ones.
1 parent b81efbb commit a33136a

File tree

4 files changed

+81
-48
lines changed

4 files changed

+81
-48
lines changed

pygitguardian/client.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,7 @@
1010
from requests import Response, Session, codes
1111

1212
from .config import DEFAULT_API_VERSION, DEFAULT_BASE_URI, DEFAULT_TIMEOUT
13-
from .iac_models import (
14-
IaCScanParameters,
15-
IaCScanParametersSchema,
16-
IaCScanResult,
17-
IaCScanResultSchema,
18-
)
13+
from .iac_models import IaCScanParameters, IaCScanParametersSchema, IaCScanResult
1914
from .models import (
2015
Detail,
2116
Document,
@@ -65,7 +60,7 @@ def load_detail(resp: Response) -> Detail:
6560
else:
6661
data = {"detail": resp.text}
6762

68-
return cast(Detail, Detail.SCHEMA.load(data))
63+
return Detail.from_dict(data)
6964

7065

7166
def is_ok(resp: Response) -> bool:
@@ -307,7 +302,7 @@ def content_scan(
307302
if filename:
308303
doc_dict["filename"] = filename
309304

310-
request_obj = Document.SCHEMA.load(doc_dict)
305+
request_obj = cast(Dict[str, Any], Document.SCHEMA.load(doc_dict))
311306
DocumentSchema.validate_size(
312307
request_obj, self.secret_scan_preferences.maximum_document_size
313308
)
@@ -320,7 +315,7 @@ def content_scan(
320315

321316
obj: Union[Detail, ScanResult]
322317
if is_ok(resp):
323-
obj = ScanResult.SCHEMA.load(resp.json())
318+
obj = ScanResult.from_dict(resp.json())
324319
else:
325320
obj = load_detail(resp)
326321

@@ -354,7 +349,9 @@ def multi_content_scan(
354349
)
355350

356351
if all(isinstance(doc, dict) for doc in documents):
357-
request_obj = Document.SCHEMA.load(documents, many=True)
352+
request_obj = cast(
353+
List[Dict[str, Any]], Document.SCHEMA.load(documents, many=True)
354+
)
358355
else:
359356
raise TypeError("each document must be a dict")
360357

@@ -377,7 +374,7 @@ def multi_content_scan(
377374

378375
obj: Union[Detail, MultiScanResult]
379376
if is_ok(resp):
380-
obj = MultiScanResult.SCHEMA.load(dict(scan_results=resp.json()))
377+
obj = MultiScanResult.from_dict({"scan_results": resp.json()})
381378
else:
382379
obj = load_detail(resp)
383380

@@ -403,7 +400,7 @@ def quota_overview(
403400

404401
obj: Union[Detail, QuotaResponse]
405402
if is_ok(resp):
406-
obj = QuotaResponse.SCHEMA.load(resp.json())
403+
obj = QuotaResponse.from_dict(resp.json())
407404
else:
408405
obj = load_detail(resp)
409406

@@ -442,7 +439,7 @@ def create_honeytoken(
442439
result.status_code = 504
443440
else:
444441
if is_create_ok(resp):
445-
result = HoneytokenResponse.SCHEMA.load(resp.json())
442+
result = HoneytokenResponse.from_dict(resp.json())
446443
else:
447444
result = load_detail(resp)
448445
result.status_code = resp.status_code
@@ -474,7 +471,7 @@ def iac_directory_scan(
474471
result.status_code = 504
475472
else:
476473
if is_ok(resp):
477-
result = IaCScanResultSchema().load(resp.json())
474+
result = IaCScanResult.from_dict(resp.json())
478475
else:
479476
result = load_detail(resp)
480477

@@ -497,7 +494,7 @@ def read_metadata(self) -> Optional[Detail]:
497494
result = load_detail(resp)
498495
result.status_code = resp.status_code
499496
return result
500-
metadata = ServerMetadata.SCHEMA.load(resp.json())
497+
metadata = ServerMetadata.from_dict(resp.json())
501498

502499
self.secret_scan_preferences = metadata.secret_scan_preferences
503500
return None

pygitguardian/iac_models.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
from dataclasses import dataclass, field
2-
from typing import List, Optional
2+
from typing import List, Optional, Type, cast
33

44
import marshmallow_dataclass
55

6-
from pygitguardian.models import Base, BaseSchema
6+
from pygitguardian.models import Base, BaseSchema, FromDictMixin
77

88

99
@dataclass
10-
class IaCVulnerability(Base):
10+
class IaCVulnerability(Base, FromDictMixin):
1111
policy: str
1212
policy_id: str
1313
line_end: int
@@ -18,37 +18,48 @@ class IaCVulnerability(Base):
1818
severity: str = ""
1919

2020

21-
IaCVulnerabilitySchema = marshmallow_dataclass.class_schema(
22-
IaCVulnerability, BaseSchema
21+
IaCVulnerabilitySchema = cast(
22+
Type[BaseSchema], marshmallow_dataclass.class_schema(IaCVulnerability, BaseSchema)
2323
)
2424

25+
IaCVulnerability.SCHEMA = IaCVulnerabilitySchema()
26+
2527

2628
@dataclass
27-
class IaCFileResult(Base):
29+
class IaCFileResult(Base, FromDictMixin):
2830
filename: str
2931
incidents: List[IaCVulnerability]
3032

3133

32-
IaCFileResultSchema = marshmallow_dataclass.class_schema(IaCFileResult, BaseSchema)
34+
IaCFileResultSchema = cast(
35+
Type[BaseSchema], marshmallow_dataclass.class_schema(IaCFileResult, BaseSchema)
36+
)
37+
38+
IaCFileResult.SCHEMA = IaCFileResultSchema()
3339

3440

3541
@dataclass
36-
class IaCScanParameters(Base):
42+
class IaCScanParameters(Base, FromDictMixin):
3743
ignored_policies: List[str] = field(default_factory=list)
3844
minimum_severity: Optional[str] = None
3945

4046

41-
IaCScanParametersSchema = marshmallow_dataclass.class_schema(
42-
IaCScanParameters, BaseSchema
47+
IaCScanParametersSchema = cast(
48+
Type[BaseSchema], marshmallow_dataclass.class_schema(IaCScanParameters, BaseSchema)
4349
)
4450

51+
IaCScanParameters.SCHEMA = IaCScanParametersSchema()
52+
4553

4654
@dataclass
47-
class IaCScanResult(Base):
55+
class IaCScanResult(Base, FromDictMixin):
4856
id: str = ""
4957
type: str = ""
5058
iac_engine_version: str = ""
5159
entities_with_incidents: List[IaCFileResult] = field(default_factory=list)
5260

5361

54-
IaCScanResultSchema = marshmallow_dataclass.class_schema(IaCScanResult, BaseSchema)
62+
IaCScanResultSchema = cast(
63+
Type[BaseSchema], marshmallow_dataclass.class_schema(IaCScanResult, BaseSchema)
64+
)
65+
IaCScanResult.SCHEMA = IaCScanResultSchema()

pygitguardian/models.py

Lines changed: 45 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,48 @@
1313
pre_load,
1414
validate,
1515
)
16+
from typing_extensions import Self
1617

1718
from .config import DOCUMENT_SIZE_THRESHOLD_BYTES, MULTI_DOCUMENT_LIMIT
1819

1920

21+
class ToDictMixin:
22+
"""
23+
Provides a type-safe `to_dict()` method for classes using Marshmallow
24+
"""
25+
26+
SCHEMA: ClassVar[Schema]
27+
28+
def to_dict(self) -> Dict[str, Any]:
29+
return cast(Dict[str, Any], self.SCHEMA.dump(self))
30+
31+
32+
class FromDictMixin:
33+
"""This class must be used as an additional base class for all classes whose schema
34+
implements a `post_load` function turning the received dict into a class instance.
35+
36+
It makes it possible to deserialize an object using `MyClass.from_dict(dct)` instead
37+
of `MyClass.SCHEMA.load(dct)`. The `from_dict()` method is shorter, but more
38+
importantly, type-safe: its return type is an instance of `MyClass`, not
39+
`list[Any] | Any`.
40+
41+
Reference: https://marshmallow.readthedocs.io/en/stable/quickstart.html#deserializing-to-objects E501
42+
"""
43+
44+
SCHEMA: ClassVar[Schema]
45+
46+
@classmethod
47+
def from_dict(cls, dct: Dict[str, Any]) -> Self:
48+
return cast(Self, cls.SCHEMA.load(dct))
49+
50+
2051
class BaseSchema(Schema):
2152
class Meta:
2253
ordered = True
2354
unknown = EXCLUDE
2455

2556

26-
class Base:
27-
SCHEMA: ClassVar[BaseSchema]
28-
57+
class Base(ToDictMixin):
2958
def __init__(self, status_code: Optional[int] = None) -> None:
3059
self.status_code = status_code
3160

@@ -35,12 +64,6 @@ def to_json(self) -> str:
3564
"""
3665
return cast(str, self.SCHEMA.dumps(self))
3766

38-
def to_dict(self) -> Dict:
39-
"""
40-
to_dict converts model to a dictionary representation.
41-
"""
42-
return cast(Dict, self.SCHEMA.dump(self))
43-
4467
@property
4568
def success(self) -> bool:
4669
return self.__bool__()
@@ -122,7 +145,7 @@ def make_detail_response(self, data: Dict[str, Any], **kwargs: Any) -> "Detail":
122145
return Detail(**data)
123146

124147

125-
class Detail(Base):
148+
class Detail(Base, FromDictMixin):
126149
"""Detail is a response object mostly returned on error or when the
127150
api output is a simple string.
128151
@@ -155,7 +178,7 @@ def make_match(self, data: Dict[str, Any], **kwargs: Any) -> "Match":
155178
return Match(**data)
156179

157180

158-
class Match(Base):
181+
class Match(Base, FromDictMixin):
159182
"""
160183
Match describes a found issue by GitGuardian.
161184
With info such as match location and type.
@@ -219,7 +242,7 @@ def make_policy_break(self, data: Dict[str, Any], **kwargs: Any) -> "PolicyBreak
219242
return PolicyBreak(**data)
220243

221244

222-
class PolicyBreak(Base):
245+
class PolicyBreak(Base, FromDictMixin):
223246
"""
224247
PolicyBreak describes a GitGuardian policy break found
225248
in a scan.
@@ -269,7 +292,7 @@ def make_scan_result(self, data: Dict[str, Any], **kwargs: Any) -> "ScanResult":
269292
return ScanResult(**data)
270293

271294

272-
class ScanResult(Base):
295+
class ScanResult(Base, FromDictMixin):
273296
"""ScanResult is a response object returned on a Content Scan
274297
275298
Attributes:
@@ -355,7 +378,7 @@ def make_scan_result(
355378
return MultiScanResult(**data)
356379

357380

358-
class MultiScanResult(Base):
381+
class MultiScanResult(Base, FromDictMixin):
359382
"""ScanResult is a response object returned on a Content Scan
360383
361384
Attributes:
@@ -425,7 +448,7 @@ def make_quota(self, data: Dict[str, Any], **kwargs: Any) -> "Quota":
425448
return Quota(**data)
426449

427450

428-
class Quota(Base):
451+
class Quota(Base, FromDictMixin):
429452
"""
430453
Quota describes a quota category in the GitGuardian API.
431454
Allows you to check your current available quota.
@@ -468,7 +491,7 @@ def make_quota_response(
468491
return QuotaResponse(**data)
469492

470493

471-
class QuotaResponse(Base):
494+
class QuotaResponse(Base, FromDictMixin):
472495
"""
473496
Quota describes a quota category in the GitGuardian API.
474497
Allows you to check your current available quota.
@@ -515,7 +538,7 @@ def make_honeytoken_response(
515538
return HoneytokenResponse(**data)
516539

517540

518-
class HoneytokenResponse(Base):
541+
class HoneytokenResponse(Base, FromDictMixin):
519542
"""
520543
honeytoken creation in the GitGuardian API.
521544
Allows users to create and get a honeytoken.
@@ -633,14 +656,15 @@ class SecretScanPreferences:
633656

634657

635658
@dataclass
636-
class ServerMetadata(Base):
659+
class ServerMetadata(Base, FromDictMixin):
637660
version: str
638661
preferences: Dict[str, Any]
639662
secret_scan_preferences: SecretScanPreferences = field(
640663
default_factory=SecretScanPreferences
641664
)
642665

643666

644-
ServerMetadata.SCHEMA = marshmallow_dataclass.class_schema(
645-
ServerMetadata, base_schema=BaseSchema
646-
)()
667+
ServerMetadata.SCHEMA = cast(
668+
BaseSchema,
669+
marshmallow_dataclass.class_schema(ServerMetadata, base_schema=BaseSchema)(),
670+
)

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def get_version() -> str:
3535
"marshmallow>=3.5, <4",
3636
"requests>=2, <3",
3737
"marshmallow-dataclass >=8.5.8, <8.6.0",
38+
"typing-extensions",
3839
],
3940
include_package_data=True,
4041
zip_safe=True,

0 commit comments

Comments
 (0)