Skip to content

Commit 2195111

Browse files
committed
improved to/from_dict/string() implementation
1 parent 5f9efc5 commit 2195111

File tree

2 files changed

+71
-8
lines changed

2 files changed

+71
-8
lines changed

airbyte_cdk/models/airbyte_protocol.py

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
33
#
44

5-
from dataclasses import InitVar, asdict, dataclass
5+
from dataclasses import InitVar, dataclass
6+
from functools import cached_property
67
from typing import Annotated, Any, Dict, List, Mapping, Optional, Union
78

89
import orjson
910
from airbyte_protocol_dataclasses.models import * # noqa: F403 # Allow '*'
11+
from serpyco_rs import CustomType, Serializer
1012
from serpyco_rs.metadata import Alias
1113

1214
# ruff: noqa: F405 # ignore fuzzy import issues with 'import *'
@@ -51,6 +53,22 @@ def __eq__(self, other: object) -> bool:
5153
)
5254

5355

56+
class AirbyteStateBlobType(CustomType[AirbyteStateBlob, dict[str, Any]]):
    """serpyco-rs custom type converting ``AirbyteStateBlob`` to and from a plain dict."""

    def serialize(self, value: AirbyteStateBlob) -> dict[str, Any]:
        """Return a shallow dict copy of the blob's instance attributes."""
        # orjson.dumps() can't be used directly because private attributes are
        # excluded, e.g. "__ab_full_refresh_sync_complete"
        return dict(value.__dict__)

    def deserialize(self, value: dict[str, Any]) -> AirbyteStateBlob:
        """Build an ``AirbyteStateBlob`` from a plain dict."""
        blob = AirbyteStateBlob(value)
        return blob

    def get_json_schema(self) -> dict[str, Any]:
        """Return the JSON schema of the serialized form: a plain object."""
        schema: dict[str, Any] = {"type": "object"}
        return schema
66+
67+
68+
def custom_type_resolver(t: type) -> CustomType[AirbyteStateBlob, dict[str, Any]] | None:
    """Return an ``AirbyteStateBlobType`` when ``t`` is ``AirbyteStateBlob`` itself, else ``None``."""
    if t is not AirbyteStateBlob:
        return None
    return AirbyteStateBlobType()
70+
71+
5472
# The following dataclasses have been redeclared to include the new version of AirbyteStateBlob
5573
@dataclass
5674
class AirbyteStreamState:
@@ -75,6 +93,33 @@ class AirbyteStateMessage:
7593
sourceStats: Optional[AirbyteStateStats] = None # type: ignore [name-defined]
7694
destinationStats: Optional[AirbyteStateStats] = None # type: ignore [name-defined]
7795

96+
def to_dict(self) -> dict:
    """Dump this object to a dictionary through its ``_serializer``."""
    serializer = self._serializer
    return serializer.dump(self)
98+
99+
def to_string(self) -> str:
    """Return the ``to_dict()`` form of this object encoded as a JSON string."""
    encoded = orjson.dumps(self.to_dict())
    return encoded.decode("utf-8")
101+
102+
@classmethod
def from_string(cls, string: str, /) -> "AirbyteStateMessage":
    """Deserialize a JSON string into an AirbyteStateMessage object.

    Declared as a classmethod so it can be used as an alternate constructor
    (``AirbyteStateMessage.from_string(...)``) without first having an
    instance; calling it on an instance keeps working.

    Requires ``cls._serializer`` to resolve to a ``Serializer`` when accessed
    on the class.

    Args:
        string: JSON text, parsed with ``orjson.loads``.

    Returns:
        The object produced by the serializer's ``load``.
    """
    return cls._serializer.load(orjson.loads(string))
105+
106+
@classmethod
def from_dict(cls, dictionary: dict, /) -> "AirbyteStateMessage":
    """Deserialize a dictionary into an AirbyteStateMessage object.

    Declared as a classmethod so it can be used as an alternate constructor
    (``AirbyteStateMessage.from_dict(...)``) without first having an
    instance; calling it on an instance keeps working.

    Requires ``cls._serializer`` to resolve to a ``Serializer`` when accessed
    on the class.

    Args:
        dictionary: The dictionary handed to the serializer's ``load``.

    Returns:
        The object produced by the serializer's ``load``.
    """
    return cls._serializer.load(dictionary)
109+
110+
class _SerializerDescriptor:
    """Descriptor that lazily builds and caches the AirbyteStateMessage serializer.

    This replaces ``@cached_property`` stacked on ``@classmethod``, which does
    not work: on instance access ``cached_property`` tries to call the
    ``classmethod`` object (which is not callable), and on class access it
    returns the ``cached_property`` object itself instead of a ``Serializer``.

    This descriptor returns the same ``Serializer`` for both
    ``instance._serializer`` and ``AirbyteStateMessage._serializer``.
    """

    def __init__(self) -> None:
        # Built on first access: the dataclass does not exist yet while its
        # own class body is being executed.
        self._cached = None

    def __get__(self, instance: Any, owner: Optional[type] = None) -> "Serializer":
        """Return the cached serializer, building it on first access."""
        if self._cached is None:
            self._cached = Serializer(
                AirbyteStateMessage,
                omit_none=True,
                custom_type_resolver=custom_type_resolver,
            )
        return self._cached

# Un-annotated class attribute, so the dataclass machinery does not treat it as a field.
_serializer = _SerializerDescriptor()
122+
78123

79124
@dataclass
80125
class AirbyteMessage:
@@ -89,7 +134,28 @@ class AirbyteMessage:
89134
control: Optional[AirbyteControlMessage] = None # type: ignore [name-defined]
90135

91136
def to_dict(self) -> dict:
    """Dump this message to a dictionary through its ``_serializer``.

    The superseded ``asdict(self)`` return is dropped: ``asdict`` is no longer
    imported, and leaving it in place made the serializer call unreachable.

    Returns:
        Whatever the serializer's ``dump`` returns for this object.
    """
    return self._serializer.dump(self)
93138

94139
def to_string(self) -> str:
    """Return the ``to_dict()`` form of this message encoded as a JSON string."""
    payload = self.to_dict()
    raw = orjson.dumps(payload)
    return raw.decode("utf-8")
141+
142+
@classmethod
def from_string(cls, string: str, /) -> "AirbyteMessage":
    """Deserialize a JSON string into an AirbyteMessage object.

    Declared as a classmethod so it can be used as an alternate constructor
    (``AirbyteMessage.from_string(...)``) without first having an instance;
    calling it on an instance keeps working.

    Requires ``cls._serializer`` to resolve to a ``Serializer`` when accessed
    on the class.

    Args:
        string: JSON text, parsed with ``orjson.loads``.

    Returns:
        The object produced by the serializer's ``load``.
    """
    return cls._serializer.load(orjson.loads(string))
145+
146+
@classmethod
def from_dict(cls, dictionary: dict, /) -> "AirbyteMessage":
    """Deserialize a dictionary into an AirbyteMessage object.

    Declared as a classmethod so it can be used as an alternate constructor
    (``AirbyteMessage.from_dict(...)``) without first having an instance;
    calling it on an instance keeps working.

    Requires ``cls._serializer`` to resolve to a ``Serializer`` when accessed
    on the class.

    Args:
        dictionary: The dictionary handed to the serializer's ``load``.

    Returns:
        The object produced by the serializer's ``load``.
    """
    return cls._serializer.load(dictionary)
149+
150+
class _SerializerDescriptor:
    """Descriptor that lazily builds and caches the AirbyteMessage serializer.

    This replaces ``@cached_property`` stacked on ``@classmethod``, which does
    not work: on instance access ``cached_property`` tries to call the
    ``classmethod`` object (which is not callable), and on class access it
    returns the ``cached_property`` object itself instead of a ``Serializer``.

    This descriptor returns the same ``Serializer`` for both
    ``instance._serializer`` and ``AirbyteMessage._serializer``.
    """

    def __init__(self) -> None:
        # Built on first access: the dataclass does not exist yet while its
        # own class body is being executed.
        self._cached = None

    def __get__(self, instance: Any, owner: Optional[type] = None) -> "Serializer":
        """Return the cached serializer, building it on first access."""
        if self._cached is None:
            self._cached = Serializer(
                AirbyteMessage,
                omit_none=True,
                custom_type_resolver=custom_type_resolver,
            )
        return self._cached

# Un-annotated class attribute, so the dataclass machinery does not treat it as a field.
_serializer = _SerializerDescriptor()

airbyte_cdk/models/airbyte_protocol_serializers.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,9 @@ def custom_type_resolver(t: type) -> CustomType[AirbyteStateBlob, Dict[str, Any]
3333
AirbyteStreamStateSerializer = Serializer(
3434
AirbyteStreamState, omit_none=True, custom_type_resolver=custom_type_resolver
3535
)
36-
AirbyteStateMessageSerializer = Serializer(
37-
AirbyteStateMessage, omit_none=True, custom_type_resolver=custom_type_resolver
38-
)
39-
AirbyteMessageSerializer = Serializer(
40-
AirbyteMessage, omit_none=True, custom_type_resolver=custom_type_resolver
41-
)
4236
ConfiguredAirbyteCatalogSerializer = Serializer(ConfiguredAirbyteCatalog, omit_none=True)
4337
ConfiguredAirbyteStreamSerializer = Serializer(ConfiguredAirbyteStream, omit_none=True)
4438
ConnectorSpecificationSerializer = Serializer(ConnectorSpecification, omit_none=True)
39+
40+
# Build these serializers explicitly instead of aliasing `<Class>._serializer`:
# a `cached_property` read from the class (rather than from an instance) yields
# the descriptor object itself, which has no `dump`/`load`.
AirbyteStateMessageSerializer = Serializer(
    AirbyteStateMessage, omit_none=True, custom_type_resolver=custom_type_resolver
)
AirbyteMessageSerializer = Serializer(
    AirbyteMessage, omit_none=True, custom_type_resolver=custom_type_resolver
)

0 commit comments

Comments
 (0)