Skip to content

Commit de27d31

Browse files
committed
Refactor serialization module to simplify type conversion and enhance Pydantic integration
1 parent 4de6e4e commit de27d31

File tree

1 file changed

+86
-161
lines changed

1 file changed

+86
-161
lines changed

src/iop/_serialization.py

Lines changed: 86 additions & 161 deletions
Original file line numberDiff line numberDiff line change
@@ -1,173 +1,52 @@
11
from __future__ import annotations
2-
import base64
32
import codecs
4-
import datetime
5-
import decimal
63
import importlib
7-
import json
4+
import inspect
85
import pickle
9-
import uuid
6+
import json
7+
from dataclasses import asdict, is_dataclass
108
from typing import Any, Dict, Type
119

12-
from dacite import Config, from_dict
1310
import iris
11+
from pydantic import BaseModel, TypeAdapter
1412

1513
from iop._message import _PydanticPickleMessage
1614
from iop._utils import _Utils
17-
from pydantic import BaseModel
18-
19-
# Constants
20-
DATETIME_FORMAT_LENGTH = 23
21-
TIME_FORMAT_LENGTH = 12
22-
TYPE_SEPARATOR = ':'
23-
SUPPORTED_TYPES = {
24-
'datetime', 'date', 'time', 'dataframe',
25-
'decimal', 'uuid', 'bytes'
26-
}
2715

2816
class SerializationError(Exception):
    """Raised when a message cannot be serialized or deserialized."""
3119

32-
class TypeConverter:
33-
"""Handles type conversion for special data types."""
34-
35-
@staticmethod
36-
def convert_to_string(typ: str, obj: Any) -> str:
37-
if typ == 'dataframe':
38-
return obj.to_json(orient="table")
39-
elif typ == 'datetime':
40-
return TypeConverter._format_datetime(obj)
41-
elif typ == 'date':
42-
return obj.isoformat()
43-
elif typ == 'time':
44-
return TypeConverter._format_time(obj)
45-
elif typ == 'bytes':
46-
return base64.b64encode(obj).decode("UTF-8")
47-
return str(obj)
48-
49-
@staticmethod
50-
def convert_from_string(typ: str, val: str) -> Any:
51-
try:
52-
if typ == 'datetime':
53-
return datetime.datetime.fromisoformat(val)
54-
elif typ == 'date':
55-
return datetime.date.fromisoformat(val)
56-
elif typ == 'time':
57-
return datetime.time.fromisoformat(val)
58-
elif typ == 'dataframe':
59-
try:
60-
import pandas as pd
61-
except ImportError:
62-
raise SerializationError("Failed to load pandas module")
63-
return pd.read_json(val, orient="table")
64-
elif typ == 'decimal':
65-
return decimal.Decimal(val)
66-
elif typ == 'uuid':
67-
return uuid.UUID(val)
68-
elif typ == 'bytes':
69-
return base64.b64decode(val.encode("UTF-8"))
70-
return val
71-
except Exception as e:
72-
raise SerializationError(f"Failed to convert type {typ}: {str(e)}")
20+
class TempPydanticModel(BaseModel):
    """Permissive Pydantic model with no declared fields.

    The configuration allows arbitrary value types and keeps undeclared
    fields (``extra='allow'``), so any dictionary can be validated into it.
    """

    model_config = {
        'arbitrary_types_allowed': True,
        'extra': 'allow',
    }
7325

74-
@staticmethod
75-
def _format_datetime(dt: datetime.datetime) -> str:
76-
r = dt.isoformat()
77-
if dt.microsecond:
78-
r = r[:DATETIME_FORMAT_LENGTH] + r[26:]
79-
if r.endswith("+00:00"):
80-
r = r[:-6] + "Z"
81-
return r
26+
class MessageSerializer:
27+
"""Handles message serialization and deserialization."""
8228

8329
@staticmethod
def _convert_to_json_safe(obj: Any) -> str:
    """Serialize a Pydantic model or dataclass instance to a JSON string.

    Args:
        obj: A ``BaseModel`` instance or a dataclass instance.

    Returns:
        The JSON text produced by ``model_dump_json()``. Dataclasses are
        first converted with ``dataclass_to_dict`` and validated into a
        ``TempPydanticModel`` so Pydantic performs the JSON encoding.

    Raises:
        SerializationError: If ``obj`` is neither a Pydantic model instance
            nor a dataclass instance.
    """
    if isinstance(obj, BaseModel):
        return obj.model_dump_json()
    # is_dataclass() is also true for dataclass *types*; only instances can
    # be converted, so a class falls through to the SerializationError below
    # instead of failing with a raw TypeError inside asdict().
    if is_dataclass(obj) and not isinstance(obj, type):
        payload = dataclass_to_dict(obj)
        return TempPydanticModel.model_validate(payload).model_dump_json()
    raise SerializationError(f"Object {obj} must be a Pydantic model or dataclass")
13538

13639
@staticmethod
def serialize(message: Any, use_pickle: bool = False) -> iris.cls:
    """Serialize ``message`` into an IRIS message object.

    Pickle is used when ``message`` is a ``_PydanticPickleMessage`` or when
    ``use_pickle`` is true; every other message is serialized as JSON.
    """
    wants_pickle = isinstance(message, _PydanticPickleMessage) or use_pickle
    if not wants_pickle:
        return MessageSerializer._serialize_json(message)
    return MessageSerializer._serialize_pickle(message)
14845

149-
@staticmethod
150-
def deserialize(serial: iris.cls, use_pickle: bool = False) -> Any:
151-
"""Deserializes a message from IRIS format."""
152-
if use_pickle:
153-
return MessageSerializer._deserialize_pickle(serial)
154-
return MessageSerializer._deserialize_json(serial)
155-
156-
@staticmethod
157-
def _serialize_pickle(message: Any) -> iris.cls:
158-
pickle_string = codecs.encode(pickle.dumps(message), "base64").decode()
159-
msg = iris.cls('IOP.PickleMessage')._New()
160-
msg.classname = f"{message.__class__.__module__}.{message.__class__.__name__}"
161-
msg.jstr = _Utils.string_to_stream(pickle_string)
162-
return msg
163-
16446
@staticmethod
16547
def _serialize_json(message: Any) -> iris.cls:
166-
if isinstance(message, BaseModel):
167-
json_string = json.dumps(message.model_dump(), cls=IrisJSONEncoder, ensure_ascii=False)
168-
else:
169-
json_string = json.dumps(message, cls=IrisJSONEncoder, ensure_ascii=False)
170-
48+
json_string = MessageSerializer._convert_to_json_safe(message)
49+
17150
msg = iris.cls('IOP.Message')._New()
17251
msg.classname = f"{message.__class__.__module__}.{message.__class__.__name__}"
17352

@@ -178,9 +57,10 @@ def _serialize_json(message: Any) -> iris.cls:
17857
return msg
17958

18059
@staticmethod
def deserialize(serial: iris.cls, use_pickle: bool = False) -> Any:
    """Deserialize an IRIS message object back into a Python message.

    ``use_pickle`` selects the pickle decoder; otherwise the JSON decoder
    is used.
    """
    if not use_pickle:
        return MessageSerializer._deserialize_json(serial)
    return MessageSerializer._deserialize_pickle(serial)
18464

18565
@staticmethod
18666
def _deserialize_json(serial: iris.cls) -> Any:
@@ -198,14 +78,28 @@ def _deserialize_json(serial: iris.cls) -> Any:
19878
if serial.type == 'Stream' else serial.json)
19979

20080
try:
201-
json_dict = json.loads(json_string, cls=IrisJSONDecoder)
20281
if issubclass(msg_class, BaseModel):
203-
return msg_class.model_validate(json_dict)
82+
return msg_class.model_validate_json(json_string)
83+
elif is_dataclass(msg_class):
84+
return dataclass_from_dict(msg_class, json.loads(json_string))
20485
else:
205-
return dataclass_from_dict(msg_class, json_dict)
86+
raise SerializationError(f"Class {msg_class} must be a Pydantic model or dataclass")
20687
except Exception as e:
20788
raise SerializationError(f"Failed to deserialize JSON: {str(e)}")
20889

90+
@staticmethod
def _serialize_pickle(message: Any) -> iris.cls:
    """Wrap a pickled copy of ``message`` in an ``IOP.PickleMessage``.

    The pickle bytes are base64-encoded to text and stored as a stream in
    ``jstr``; ``classname`` records the message's module and class name.
    """
    encoded = codecs.encode(pickle.dumps(message), "base64").decode()
    msg = iris.cls('IOP.PickleMessage')._New()
    message_type = message.__class__
    msg.classname = f"{message_type.__module__}.{message_type.__name__}"
    msg.jstr = _Utils.string_to_stream(encoded)
    return msg
97+
98+
@staticmethod
def _deserialize_pickle(serial: iris.cls) -> Any:
    """Rebuild a Python object from the base64 pickle held in ``serial.jstr``.

    NOTE(security): ``pickle.loads`` can execute arbitrary code while
    unpickling, so ``serial`` must come from a trusted producer.
    """
    encoded = _Utils.stream_to_string(serial.jstr)
    raw = codecs.decode(encoded.encode(), "base64")
    return pickle.loads(raw)
102+
209103
@staticmethod
210104
def _parse_classname(classname: str) -> tuple[str, str]:
211105
j = classname.rindex(".")
@@ -214,18 +108,49 @@ def _parse_classname(classname: str) -> tuple[str, str]:
214108
return classname[:j], classname[j+1:]
215109

216110
def dataclass_from_dict(klass: Type, dikt: Dict) -> Any:
    """Build an instance of ``klass`` from a plain dictionary.

    Every constructor parameter of ``klass`` is looked up in ``dikt``:

    * absent with a default: the constructor's own default is used;
    * absent without a default, or explicitly ``None``: ``None`` is passed;
    * present and annotated: the value is coerced with a Pydantic
      ``TypeAdapter``; if coercion fails the raw value is passed instead;
    * present and unannotated: the raw value is passed.

    Keys of ``dikt`` that are not constructor parameters are set as plain
    attributes on the new instance.

    Args:
        klass: Class to instantiate (typically a dataclass).
        dikt: Mapping of field names to values.

    Returns:
        The populated ``klass`` instance.
    """
    params = inspect.signature(klass).parameters
    init_kwargs = {}
    for name, param in params.items():
        # Identity check against the sentinel: ``!=`` would call the
        # default's own ``__ne__``, which may raise or return a non-bool.
        if name not in dikt and param.default is not inspect.Parameter.empty:
            # Leave the argument out so the constructor applies its default.
            continue

        value = dikt.get(name)
        if value is None or param.annotation is inspect.Parameter.empty:
            init_kwargs[name] = value
            continue

        try:
            init_kwargs[name] = TypeAdapter(param.annotation).validate_python(value)
        except Exception:
            # Deliberately best-effort: a value that cannot be coerced is
            # passed through unchanged rather than aborting deserialization.
            init_kwargs[name] = value

    instance = klass(**init_kwargs)
    # Preserve keys that are not constructor parameters as extra attributes.
    for key, value in dikt.items():
        if key not in params:
            setattr(instance, key, value)
    return instance
144+
145+
def dataclass_to_dict(instance: Any) -> Dict:
    """Convert a dataclass instance to a dictionary, keeping extra attributes.

    Declared fields are converted with ``dataclasses.asdict``. Attributes
    that were set on the instance but are not declared fields are then
    added as-is, without overriding any declared field.

    Args:
        instance: A dataclass instance.

    Returns:
        A dictionary of declared fields plus any undeclared attributes.
    """
    dikt = asdict(instance)
    # Slotted dataclasses have no ``__dict__`` (``vars()`` would raise
    # TypeError); they cannot carry extra attributes, so there is nothing to add.
    for key, value in getattr(instance, "__dict__", {}).items():
        dikt.setdefault(key, value)
    return dikt
229154

230155
# Maintain backwards compatibility
def serialize_pickle_message(msg: Any) -> iris.cls:
    """Serialize ``msg`` with pickle via ``MessageSerializer.serialize``."""
    return MessageSerializer.serialize(msg, use_pickle=True)

0 commit comments

Comments
 (0)