Skip to content

Commit 607e2f2

Browse files
committed
make mypy happy with mixin classes
1 parent 2def5fc commit 607e2f2

File tree

1 file changed

+79
-39
lines changed

1 file changed

+79
-39
lines changed

airbyte_cdk/models/airbyte_protocol.py

Lines changed: 79 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,15 @@
11
# Copyright (c) 2025 Airbyte, Inc., all rights reserved.
2-
from collections.abc import Callable
2+
from collections.abc import Callable, Mapping
33
from dataclasses import InitVar, dataclass
4-
from typing import Annotated, Any, Dict, List, Mapping, Optional, Union
4+
from typing import (
5+
Annotated,
6+
Any,
7+
Dict,
8+
List,
9+
Optional,
10+
Type,
11+
TypeVar,
12+
)
513

614
import orjson
715
from airbyte_protocol_dataclasses.models import * # noqa: F403 # Allow '*'
@@ -48,21 +56,74 @@ def __eq__(self, other: object) -> bool:
4856
)
4957

5058

59+
T = TypeVar("T", bound="SerDeMixin")
60+
61+
62+
class SerDeMixin:
    """Mixin that adds ``to_dict``/``to_json``/``from_dict``/``from_json`` helpers.

    Each concrete subclass gets its own ``Serializer`` (configured with
    ``omit_none=True`` and the subclass's ``_type_resolver``). The serializer is
    built lazily, on first use, and then cached on the class.

    Why lazy: the subclass hook ``__init_subclass__`` runs while the class object
    is being created, which is *before* a class decorator such as ``@dataclass``
    is applied. Building the serializer there would hand ``Serializer`` a class
    that has not been turned into a dataclass yet. Deferring construction to the
    first (de)serialization call guarantees the fully decorated class is used.
    """

    # allow subclasses to override their resolver if they need one
    _type_resolver: Callable[[type], CustomType[Any, Any] | None] | None = None

    @classmethod
    def _get_serializer(cls) -> Any:
        """Return the ``Serializer`` for ``cls``, building and caching it on first use."""
        # Look in ``cls.__dict__`` instead of using normal attribute lookup so that a
        # subclass never silently reuses the serializer cached on one of its bases.
        serializer = cls.__dict__.get("_serializer")
        if serializer is None:
            serializer = Serializer(
                cls,
                omit_none=True,
                custom_type_resolver=cls._type_resolver,
            )
            # setattr keeps the cache under the historical ``_serializer`` name.
            setattr(cls, "_serializer", serializer)
        return serializer

    def to_dict(self) -> Dict[str, Any]:
        """Dump this instance to a plain dictionary using the class serializer."""
        return self._get_serializer().dump(self)

    def to_json(self) -> str:
        """Dump this instance to a JSON string (UTF-8 decoded ``orjson`` output)."""
        # use to_dict so you only have one canonical dump
        return orjson.dumps(self.to_dict()).decode("utf-8")

    @classmethod
    def from_dict(cls: type[T], data: Dict[str, Any]) -> T:
        """Load an instance of ``cls`` from a dictionary."""
        return cls._get_serializer().load(data)

    @classmethod
    def from_json(cls: type[T], s: str) -> T:
        """Load an instance of ``cls`` from a JSON string."""
        return cls._get_serializer().load(orjson.loads(s))
90+
91+
def _custom_state_resolver(t: type) -> CustomType[AirbyteStateBlob, dict[str, Any]] | None:
    """Resolve the custom (de)serializer for ``AirbyteStateBlob``.

    Returns a new ``AirbyteStateBlobType`` instance when ``t`` is exactly
    ``AirbyteStateBlob``; returns ``None`` for every other type.
    """

    class AirbyteStateBlobType(CustomType[AirbyteStateBlob, Dict[str, Any]]):
        def serialize(self, value: AirbyteStateBlob) -> Dict[str, Any]:
            # Can't use orjson.dumps() directly because private attributes are
            # excluded, e.g. "__ab_full_refresh_sync_complete"; copy the instance
            # dictionary instead.
            return dict(value.__dict__)

        def deserialize(self, value: Dict[str, Any]) -> AirbyteStateBlob:
            blob = AirbyteStateBlob(value)
            return blob

        def get_json_schema(self) -> Dict[str, Any]:
            schema: Dict[str, Any] = {"type": "object"}
            return schema

    if t is AirbyteStateBlob:
        return AirbyteStateBlobType()
    return None
104+
105+
51106
# The following dataclasses have been redeclared to include the new version of AirbyteStateBlob
52107
@dataclass
53-
class AirbyteStreamState:
108+
class AirbyteStreamState(SerDeMixin):
54109
stream_descriptor: StreamDescriptor # type: ignore [name-defined]
55110
stream_state: Optional[AirbyteStateBlob] = None
56111

112+
# override the resolver for AirbyteStreamState to use the custom one
113+
_type_resolver = _custom_state_resolver
114+
57115

58116
@dataclass
59-
class AirbyteGlobalState:
117+
class AirbyteGlobalState(SerDeMixin):
60118
stream_states: List[AirbyteStreamState]
61119
shared_state: Optional[AirbyteStateBlob] = None
62120

121+
# override the resolver for AirbyteGlobalState to use the custom one
122+
_type_resolver = _custom_state_resolver
123+
63124

64125
@dataclass
65-
class AirbyteStateMessage:
126+
class AirbyteStateMessage(SerDeMixin):
66127
type: Optional[AirbyteStateType] = None # type: ignore [name-defined]
67128
stream: Optional[AirbyteStreamState] = None
68129
global_: Annotated[AirbyteGlobalState | None, Alias("global")] = (
@@ -72,9 +133,12 @@ class AirbyteStateMessage:
72133
sourceStats: Optional[AirbyteStateStats] = None # type: ignore [name-defined]
73134
destinationStats: Optional[AirbyteStateStats] = None # type: ignore [name-defined]
74135

136+
# override the resolver for AirbyteStateMessage to use the custom one
137+
_type_resolver = _custom_state_resolver
138+
75139

76140
@dataclass
77-
class AirbyteMessage:
141+
class AirbyteMessage(SerDeMixin):
78142
type: Type # type: ignore [name-defined]
79143
log: Optional[AirbyteLogMessage] = None # type: ignore [name-defined]
80144
spec: Optional[ConnectorSpecification] = None # type: ignore [name-defined]
@@ -85,42 +149,18 @@ class AirbyteMessage:
85149
trace: Optional[AirbyteTraceMessage] = None # type: ignore [name-defined]
86150
control: Optional[AirbyteControlMessage] = None # type: ignore [name-defined]
87151

152+
# override the resolver for AirbyteMessage to use the custom one
153+
_type_resolver = _custom_state_resolver
88154

89-
# Add optimized serdes methods to the protocol data classes:
90155

91-
def _with_serdes(
92-
cls,
93-
type_resolver: Callable[[type], CustomType[Any, Any] | None] | None = None,
94-
) -> type:
95-
"""Decorator to add SerDes (serialize/deserialize) methods to a dataclass."""
96-
cls._serializer = Serializer(cls, omit_none=True, custom_type_resolver=type_resolver)
97-
cls.to_dict = lambda self: self._serializer.dump(self)
98-
cls.to_json = lambda self: orjson.dumps(self._serializer.dump(self)).decode("utf-8")
99-
cls.from_json = lambda self, string: self._serializer.load(orjson.loads(string))
100-
cls.from_dict = lambda self, dictionary: self._serializer.load(dictionary)
101-
return cls
102-
103-
104-
def _custom_state_resolver(t: type) -> CustomType[AirbyteStateBlob, dict[str, Any]] | None:
105-
class AirbyteStateBlobType(CustomType[AirbyteStateBlob, Dict[str, Any]]):
106-
def serialize(self, value: AirbyteStateBlob) -> Dict[str, Any]:
107-
# cant use orjson.dumps() directly because private attributes are excluded, e.g. "__ab_full_refresh_sync_complete"
108-
return {k: v for k, v in value.__dict__.items()}
109-
110-
def deserialize(self, value: Dict[str, Any]) -> AirbyteStateBlob:
111-
return AirbyteStateBlob(value)
156+
# These don't need the custom resolver:
157+
# Rebinds the module-level name ``ConnectorSpecification`` to a subclass of the
# class previously bound to that name (presumably the star-imported protocol
# dataclass - confirm), adding the SerDeMixin helpers (to_dict / to_json /
# from_dict / from_json). No ``_type_resolver`` override is set here, so the
# mixin's default of ``None`` applies.
class ConnectorSpecification(ConnectorSpecification, SerDeMixin):
    pass
112159

113-
def get_json_schema(self) -> Dict[str, Any]:
114-
return {"type": "object"}
115160

116-
return AirbyteStateBlobType() if t is AirbyteStateBlob else None
161+
# Rebinds the module-level name ``ConfiguredAirbyteCatalog`` to a subclass of the
# class previously bound to that name (presumably the star-imported protocol
# dataclass - confirm), adding the SerDeMixin helpers (to_dict / to_json /
# from_dict / from_json). No ``_type_resolver`` override is set here, so the
# mixin's default of ``None`` applies.
class ConfiguredAirbyteCatalog(ConfiguredAirbyteCatalog, SerDeMixin):
    pass
117163

118164

119-
# Add serdes capabilities to all data classes that need to serialize and deserialize:
120-
AirbyteMessage = _with_serdes(AirbyteMessage, _custom_state_resolver)
121-
AirbyteStateMessage = _with_serdes(AirbyteStateMessage, _custom_state_resolver)
122-
AirbyteStreamState = _with_serdes(AirbyteStreamState, _custom_state_resolver)
123-
# These don't need the custom resolver:
124-
ConnectorSpecification = _with_serdes(ConnectorSpecification)
125-
ConfiguredAirbyteCatalog = _with_serdes(ConfiguredAirbyteCatalog)
126-
AirbyteStream = _with_serdes(AirbyteStream)
165+
# Rebinds the module-level name ``AirbyteStream`` to a subclass of the class
# previously bound to that name (presumably the star-imported protocol
# dataclass - confirm), adding the SerDeMixin helpers (to_dict / to_json /
# from_dict / from_json). No ``_type_resolver`` override is set here, so the
# mixin's default of ``None`` applies.
class AirbyteStream(AirbyteStream, SerDeMixin):
    pass

0 commit comments

Comments
 (0)