Skip to content

Commit cdf5f08

Browse files
Msgpack now working
Equality etc fixed
1 parent c8198ef commit cdf5f08

File tree

6 files changed

+281
-85
lines changed

6 files changed

+281
-85
lines changed

README.md

Lines changed: 55 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22

33
# Table of Contents
44
1. [Overview](#Overview)
5-
2. [No Copy](#No-Copy)
6-
3. [Supported Types](#Supported-Types)
7-
4. [Inheritance](#Inheritance)
8-
5. [Msgpack](#Msgpack)
9-
6. [Generated Code](#Generated-Code)
5+
2. [Why Not Protobufs ?](#Why-Not-Protobufs)
6+
3. [No Copy](#No-Copy)
7+
4. [Supported Types](#Supported-Types)
8+
5. [Inheritance](#Inheritance)
9+
6. [Msgpack](#Msgpack)
10+
7. [Generated Code](#Generated-Code)
1011

1112

1213
## Overview
@@ -26,7 +27,8 @@ and [pybind11](https://pybind11.readthedocs.io/en/stable/index.html) code for bi
2627
corresponding .cpp file.
2728

2829
The intended use of this package is for defining behaviour-less data classes, to be shared between python and C++. E.g.,
29-
a common object model for financial modelling.
30+
a common object model for financial modelling. Furthr, we want idiomatic classes for each language, not mutants like
31+
Protobuf-generated python classes.
3032

3133
Note that the typcal python developer experience is now somewhat changed, in that it's necessary to build/install
3234
the project. I personally use JetBrains CLion, in place of PyCharm for such projects.
@@ -64,17 +66,37 @@ You can create an instance of the pybind class from your original using `get_pyb
6466
generated = get_pybind_value(orig)
6567

6668

69+
# Why Not Protobufs?
70+
71+
A very good question. Protobufs are frankly a PITA to use: they have poor to no variant support, the generated
72+
code is ugly and idiosyncratic, they're large and painful to copy around etc.
73+
74+
AVRO is more friendly but generates python classes dynamically, which confuses IDEs like Pycharm. I do think a good
75+
solution is something like [pydantic_avro](https://github.com/godatadriven/pydantic-avro/tree/main/src/pydantic_avro)
76+
where one can define the classes using pydantic, generate the AVRO schema and then the generateed C++ etc. I might
77+
well try and converge this project with that approach.
78+
79+
I was inspired to some degree by this [blog](https://mikeloomisgg.github.io/2019-07-02-making-a-serialization-library/).
80+
81+
6782
## No Copy
6883

69-
What if you would like a single representation of the data, shared between C++ and python. Then fear not,
70-
`BaseModelNoCopy` and `dataclass` are your friends!
84+
One annoyance of multi-language representations of data objects is that you often end up copying data around where
85+
you'd prefer to share a single copy. This is the raison d'etre for Protobufs and its ilk. In this project I've created
86+
implementations of `BaseModel` and `dataclass` which allow python to use the underlying C++ data representation, rather
87+
than holding its own copy.
7188

72-
Deriving from this `BaseModelNoCopy` will give you equivalent functionality of as pydantic's `BaseModel`. The
89+
Deriving from this `BaseModel` will give you equivalent functionality of as pydantic's `BaseModel`. The
7390
annotations are re-written using `computed_field`, with property getters and setters operating on the generated pybind
74-
class, which is instantiated behind the scenes in `init`
91+
class, which is instantiated behind the scenes in `init`. Note that this will make some operations (especially those
92+
that access __dict__) less efficient. I've also plumbed the computed fields into the JSON schema, so these objects can
93+
be used with [FastAPI](https://fastapi.tiangolo.com).
7594

7695
`dataclass` works similarly, adding properties to the dataclass, so that the exisitng get and set functionality works
77-
seamless in accessing the generated pybind class (also set via a shimmed `init`)
96+
seamless in accessing the generated pybind class (also set via a shimmed `init`).
97+
98+
Using regular `dataclass` or `BaseModel` as members of classes defined with the pydantic_bind versions is very
99+
inefficient and not recommended.
78100

79101

80102
## Supported Types
@@ -92,11 +114,12 @@ The following python -> C++ mappings are supported (there are likely others I sh
92114
- pydantic.BaseModel --> struct
93115
- pydantic_bind.BaseModelNoCopy --> struct
94116
- dataclass --> struct
117+
- Enum -> enum
95118

96119
## Inheritance
97120

98121
I have tested single inheritance (see [Generated Code](#Generated-code)). Multiple inheritance may work ... or it
99-
may not. I'd generally advise against using it for data classes
122+
may not. I'd generally advise against using it for data classes.
100123

101124

102125
## Msgpack
@@ -109,21 +132,28 @@ project with my rather rudimentary cmake skillz!) Changes include:
109132
- Fixing includes
110133
- Support for std::optional
111134
- Support for std::variant
135+
- Support for enums
136+
137+
A likely future enhancement will be to use [cereal](https://github.com/USCiLab/cereal) and add a mgspack adaptor.
138+
However, I haven't quite worked out how to do that yet.
112139

113140

114141
## Generated Code
115142

116-
Code is generated into a directory structure underneath `<top level>/generated`
143+
Code is generated into a directory structure underneath `<top level>/generated`.
117144

118-
Headers are installed to `<top level>/include`
145+
Headers are installed to `<top level>/include`.
119146

120-
Compiled pybind modules are installed into `<original module path>/__pybind__`
147+
Compiled pybind modules are installed into `<original module path>/__pybind__`.
121148

122-
For C++ usage, you need only the headers, the compiled code is for pybind/python usage only
149+
For C++ usage, you need only the headers, the compiled code is for pybind/python usage only.
123150

124151
For the example below, `common_object_model/common_object_model/v1/common/__pybind__/foo.cpython-311-darwin.so` will
125152
be installed (obviously with corresponding qualifiers for Linux/Windows). `get_pybind_value()` searches this
126-
directory
153+
directory.
154+
155+
Imports/includes should work seamlessly (the python import scheme will be copied). I have tested this but not
156+
completely rigorously.
127157

128158
*common_object_model/common_object_model/v1/common/foo.py:*
129159

@@ -290,18 +320,24 @@ will generate the following files:
290320
py::class_<DCFoo>(m, "DCFoo")
291321
.def(py::init<>())
292322
.def(py::init<std::optional<std::string>, int>(), py::arg("my_string"), py::arg("my_int"))
323+
.def("to_msg_pack", &DCFoo::to_msg_pack)
324+
.def_static("from_msg_pack", &DCFoo::from_msg_pack<Baz>)
293325
.def_readwrite("my_string", &DCFoo::my_string)
294326
.def_readwrite("my_int", &DCFoo::my_int);
295327

296328
py::class_<Foo>(m, "Foo")
297329
.def(py::init<bool, Weekday>(), py::arg("my_bool")=true, py::arg("my_day")=SUNDAY)
330+
.def("to_msg_pack", &Foo::to_msg_pack)
331+
.def_static("from_msg_pack", &Foo::from_msg_pack<Baz>)
298332
.def_readwrite("my_bool", &Foo::my_bool)
299333
.def_readwrite("my_day", &Foo::my_day);
300334

301335
py::class_<Bar>(m, "Bar")
302336
.def(py::init<>())
303337
.def(py::init<std::string, bool, Weekday, int, std::optional<std::string>>(), py::arg("my_string"), py::arg("my_bool")=true,
304338
py::arg("my_day")=SUNDAY, py::arg("my_int")=123, py::arg("my_optional_string")=std::nullopt)
339+
.def("to_msg_pack", &Bazr:to_msg_pack)
340+
.def_static("from_msg_pack", &Bar::from_msg_pack<Baz>)
305341
.def_readwrite("my_string", &Bar::my_string)
306342
.def_readwrite("my_int", &Bar::my_int)
307343
.def_readwrite("my_optional_string", &Bar::my_optional_string);
@@ -310,6 +346,8 @@ will generate the following files:
310346
.def(py::init<>())
311347
.def(py::init<DCFoo, Foo, std::chrono::system_clock::time_point, std::variant<std::string, double>>(), py::arg("my_dc_foo"),
312348
py::arg("my_foo"), py::arg("my_date"), py::arg("my_variant")=123.0)
349+
.def("to_msg_pack", &Baz::to_msg_pack)
350+
.def_static("from_msg_pack", &Baz::from_msg_pack<Baz>)
313351
.def_readwrite("my_dc_foo", &Baz::my_dc_foo)
314352
.def_readwrite("my_foo", &Baz::my_foo)
315353
.def_readwrite("my_date", &Baz::my_date)

pydantic_bind/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from .base import BaseModelNoCopy, dataclass, get_pybind_type, get_pybind_value
1+
from .base import BaseModel, dataclass, get_pybind_type, get_pybind_value
22

33
__all__ = (
4-
BaseModelNoCopy,
4+
BaseModel,
55
dataclass,
66
get_pybind_type,
77
get_pybind_value

pydantic_bind/base.py

Lines changed: 84 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from enum import Enum, EnumType
33
from functools import cache, wraps
44
from importlib import import_module
5-
from pydantic import BaseModel as BaseModel, ConfigDict, computed_field
5+
from pydantic import BaseModel as PydanticBaseModel, ConfigDict, computed_field
66
from pydantic.fields import ComputedFieldInfo, FieldInfo
77
from pydantic.json_schema import GenerateJsonSchema
88
from pydantic._internal._config import ConfigWrapper
@@ -13,27 +13,24 @@
1313
from pydantic_core import PydanticUndefined
1414
import sys
1515
from types import UnionType
16-
from typing import Any, Dict, List, Type, Optional, Union, cast, get_args, get_origin
16+
from typing import Any, Dict, List, Optional, Sequence, Type, Union, cast, get_args, get_origin
1717

1818

1919
class UnconvertableValue(Exception):
2020
pass
2121

2222

23-
class __IBaseModelNoCopy:
24-
pass
25-
26-
27-
def field_info_iter(model_class: ModelMetaclass):
23+
def field_info_iter(model_class):
2824
if is_dataclass(model_class):
2925
for field_name, field in model_class.__dataclass_fields__.items():
3026
yield field_name, field.type, field.default
31-
elif issubclass(model_class, BaseModelNoCopy):
32-
for field_name, field in model_class.__pydantic_decorators__.computed_fields.items():
33-
yield field_name, field.info.return_type, field.info.default
34-
else:
35-
for field_name, field in model_class.model_fields.items():
36-
yield field_name, field.annotation, field.default
27+
elif issubclass(model_class, PydanticBaseModel):
28+
if hasattr(model_class, "__has_pybind_impl__"):
29+
for field_name, field in model_class.__pydantic_decorators__.computed_fields.items():
30+
yield field_name, field.info.return_type, field.info.default
31+
else:
32+
for field_name, field in model_class.model_fields.items():
33+
yield field_name, field.annotation, field.default
3734

3835

3936
@cache
@@ -69,41 +66,45 @@ def get_pybind_value(obj):
6966
def _get_pybind_value(obj, default_to_self: bool = True):
7067
if isinstance(obj, Enum):
7168
return get_pybind_type(type(obj)).__entries[obj.name][0]
72-
elif is_dataclass(obj):
73-
return get_pybind_type(type(obj))(**{name: _get_pybind_value(getattr(obj, name))
74-
for name in obj.__dataclass_fields__.keys()})
75-
elif isinstance(obj, __IBaseModelNoCopy):
76-
return get_pybind_type(type(obj))(**{name: _get_pybind_value(getattr(obj, name))
77-
for name in obj.model_computed_fields.keys()})
78-
elif isinstance(obj, BaseModel):
79-
return get_pybind_type(type(obj))(**{name: _get_pybind_value(getattr(obj, name))
80-
for name in obj.model_fields.keys()})
69+
elif is_dataclass(obj) or isinstance(obj, PydanticBaseModel):
70+
typ = type(obj)
71+
pybind_type = get_pybind_type(typ)
72+
name_iter = (name for name, _, _ in field_info_iter(typ))
73+
74+
if hasattr(typ, "__has_pybind_impl__"):
75+
return pybind_type(**{name: getattr(obj.pybind_impl, name) for name in name_iter})
76+
else:
77+
return pybind_type(**{name: _get_pybind_value(getattr(obj, name)) for name in name_iter})
8178
elif default_to_self:
8279
return obj
8380
else:
84-
raise UnconvertableValue("Only dataclasses and pydantic classes supported")
81+
raise UnconvertableValue("Only builtins, dataclasses and pydantic classes supported")
8582

8683

8784
def from_pybind_value(value, typ: Type):
8885
origin = get_origin(typ)
8986
args = get_args(typ)
90-
is_dc = is_dataclass(typ)
9187

9288
if origin is Optional:
9389
typ = args[0]
94-
elif origin in (Union, UnionType):
90+
args = get_args(typ)
91+
92+
if origin in (Union, UnionType):
9593
typ = next(a for a in args if a.__name__ == type(value).__name__)
9694

95+
is_dc_or_pydantic = is_dataclass(typ) or issubclass(typ, PydanticBaseModel)
96+
9797
if issubclass(typ, Enum):
9898
return typ[value.name]
99-
elif issubclass(typ, __IBaseModelNoCopy) or (is_dc and hasattr(typ, "__no_copy__")):
100-
return typ(__pybind_impl__=value)
101-
elif is_dc or issubclass(typ, BaseModel):
102-
# This is quite inefficient
103-
kwargs = {}
104-
for field_name, field_type, _ in field_info_iter(typ):
105-
kwargs[field_name] = from_pybind_value(getattr(value, field_name), field_type)
106-
return typ(**kwargs)
99+
elif is_dc_or_pydantic:
100+
if hasattr(typ, "__has_pybind_impl__"):
101+
return typ(__pybind_impl__=value)
102+
else:
103+
# This is quite inefficient
104+
kwargs = {}
105+
for field_name, field_type, _ in field_info_iter(typ):
106+
kwargs[field_name] = from_pybind_value(getattr(value, field_name), field_type)
107+
return typ(**kwargs)
107108
else:
108109
return value
109110

@@ -194,6 +195,7 @@ def __new__(
194195
namespace[name] = prop
195196

196197
cls = cast(ModelMetaclass, super().__new__(mcs, cls_name, bases, namespace, **kwargs))
198+
cls.__has_pybind_impl__ = True
197199
cls.__pydantic_decorators__.__annotations__["computed_fields"] = dict[str, Decorator[PropertyFieldInfo]]
198200
cls.__signature__ = ClassAttribute(
199201
'__signature__', generate_model_signature(cls.__init__, field_infos, config_wrapper)
@@ -226,7 +228,13 @@ def json_schema_extra(schema: Dict[str, Any], model_class: ModelMetaclassNoCopy)
226228
properties[alias] = field_schema
227229

228230

229-
class BaseModelNoCopy(BaseModel, __IBaseModelNoCopy, metaclass=ModelMetaclassNoCopy):
231+
def _from_msg_pack(cls, data: Sequence[int]):
232+
typ = get_pybind_type(cls)
233+
pybind_impl, _error_code = typ.from_msg_pack(data)
234+
return cls(__pybind_impl__=pybind_impl)
235+
236+
237+
class BaseModel(PydanticBaseModel, metaclass=ModelMetaclassNoCopy):
230238
model_config = ConfigDict(json_schema_extra=json_schema_extra)
231239

232240
@property
@@ -237,9 +245,14 @@ def model_computed_fields(self) -> dict[str, PropertyFieldInfo]:
237245
def pybind_impl(self):
238246
return self.__pybind_impl
239247

240-
def __init__(self, **kwargs):
241-
super().__init__()
248+
def to_msg_pack(self):
249+
return self.__pybind_impl.to_msg_pack()
250+
251+
@classmethod
252+
def from_msg_pack(cls, data: Sequence[int]):
253+
return _from_msg_pack(cls, data)
242254

255+
def __init__(self, **kwargs):
243256
__pybind_impl__ = kwargs.pop("__pybind_impl__", None)
244257
if __pybind_impl__:
245258
self.__pybind_impl = __pybind_impl__
@@ -266,19 +279,45 @@ def __init__(self, **kwargs):
266279
raise RuntimeError(f"Missing required fields: {missing_required}")
267280

268281
pybind_type = get_pybind_type(type(self))
269-
self.__pybind_impl = pybind_type(**kwargs)
282+
object.__setattr__(self, "_BaseModel__pybind_impl", pybind_type(**kwargs))
283+
284+
super().__init__()
285+
286+
@property
287+
def __dict__(self):
288+
return {name: from_pybind_value(getattr(self, name), typ) for name, typ, _ in field_info_iter(type(self))}
289+
290+
@__dict__.setter
291+
def __dict__(self, value: dict):
292+
try:
293+
object.__getattribute__(self, "_BaseModel__pybind_impl")
294+
for name, value in value.items():
295+
object.__setattr__(self, name, value)
296+
except AttributeError:
297+
self.__init__(**value)
270298

271299

272300
def __dataclass_init(init):
273301
@wraps(init)
274302
def wrapper(self, *args, __pybind_impl__=None, **kwargs):
275-
self.pybind_impl = __pybind_impl__ or get_pybind_type(type(self))()
276-
init(self, *args, **kwargs)
277-
arse = True
303+
if __pybind_impl__:
304+
self.__pybind_impl = __pybind_impl__
305+
else:
306+
self.__pybind_impl = get_pybind_type(type(self))()
307+
init(self, *args, **kwargs)
278308

279309
return wrapper
280310

281311

312+
def to_msg_pack(self):
313+
return self.pybind_impl.to_msg_pack()
314+
315+
316+
@classmethod
317+
def from_msg_pack(cls, data: Sequence[int]):
318+
return _from_msg_pack(cls, data)
319+
320+
282321
def dataclass(cls=None, /, *, init=True, repr=True, eq=True, order=False,
283322
unsafe_hash=False, frozen=False, match_args=True,
284323
kw_only=False, slots=False, weakref_slot=False):
@@ -290,6 +329,10 @@ def dataclass(cls=None, /, *, init=True, repr=True, eq=True, order=False,
290329
setattr(cls, name, property(fget=_getter(name, field.type), fset=_setter(name, field.type)))
291330

292331
ret.__init__ = __dataclass_init(ret.__init__)
293-
ret.__no_copy__ = True
332+
ret.__has_pybind_impl__ = True
333+
334+
ret.to_msg_pack = to_msg_pack
335+
ret.from_msg_pack = from_msg_pack
336+
ret.pybind_impl = property(fget=lambda self: self.__pybind_impl)
294337

295338
return ret

0 commit comments

Comments
 (0)