Skip to content

Commit 080e565

Browse files
Add no_copy dataclass
1 parent c59724e commit 080e565

File tree

5 files changed

+42
-10
lines changed

5 files changed

+42
-10
lines changed

README.md

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,14 @@ You can create an instance of the pybind class from your original using `get_pyb
6767
## No Copy
6868

6969
What if you would like a single representation of the data, shared between C++ and python. Then fear not,
70-
`base.BaseModelNoCopy` is your friend! Deriving from this class will result in the annotations for the pydantic class
71-
being re-written using `computed_field`, with property getters and setters operating on the generated pybind class.
70+
`BaseModelNoCopy` and `dataclass` are your friends!
7271

73-
`BaseModelNoCopy.__init__` will create the corresponding pybind class, using the supplied values.
72+
Deriving from this `BaseModelNoCopy` will give you equivalent functionality of as pydantic's `BaseModel`. The
73+
annotations are re-written using `computed_field`, with property getters and setters operating on the generated pybind
74+
class, which is instantiated behind the scenes in `init`
75+
76+
`dataclass` works similarly, adding properties to the dataclass, so that the exisitng get and set functionality works
77+
seamless in accessing the generated pybind class (also set via a shimmed `init`)
7478

7579

7680
## Supported Types

pydantic_bind/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from .base import BaseModelNoCopy, get_pybind_type, get_pybind_value
1+
from .base import BaseModelNoCopy, dataclass, get_pybind_type, get_pybind_value
22

33
__all__ = (
44
BaseModelNoCopy,
5+
dataclass,
56
get_pybind_type,
67
get_pybind_value
78
)

pydantic_bind/base.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from dataclasses import is_dataclass
1+
from dataclasses import dataclass as orig_dataclass, is_dataclass
22
from enum import Enum, EnumType
3-
from functools import cache
3+
from functools import cache, wraps
44
from importlib import import_module
55
from pydantic import BaseModel as BaseModel, ConfigDict, computed_field
66
from pydantic.fields import ComputedFieldInfo, FieldInfo
@@ -84,9 +84,10 @@ def _get_pybind_value(obj, default_to_self: bool = True):
8484
raise UnconvertableValue("Only dataclasses and pydantic classes supported")
8585

8686

87-
def from_pybind_value(value, typ:Type):
87+
def from_pybind_value(value, typ: Type):
8888
origin = get_origin(typ)
8989
args = get_args(typ)
90+
is_dc = is_dataclass(typ)
9091

9192
if origin is Optional:
9293
typ = args[0]
@@ -95,9 +96,9 @@ def from_pybind_value(value, typ:Type):
9596

9697
if issubclass(typ, Enum):
9798
return typ[value.name]
98-
elif issubclass(typ, __IBaseModelNoCopy):
99+
elif issubclass(typ, __IBaseModelNoCopy) or (is_dc and hasattr(typ, "__no_copy__")):
99100
return typ(__pybind_impl__=value)
100-
elif is_dataclass(typ) or issubclass(typ, BaseModel):
101+
elif is_dc or issubclass(typ, BaseModel):
101102
# This is quite inefficient
102103
kwargs = {}
103104
for field_name, field_type, _ in field_info_iter(typ):
@@ -266,3 +267,28 @@ def __init__(self, **kwargs):
266267

267268
pybind_type = get_pybind_type(type(self))
268269
self.__pybind_impl = pybind_type(**kwargs)
270+
271+
272+
def __dc_init(init):
273+
@wraps(init)
274+
def wrapper(self, *args, __pybind_impl__=None, **kwargs):
275+
self.__pybind_impl = __pybind_impl__ or get_pybind_type(type(self))()
276+
return init(self, *args, **kwargs)
277+
278+
return wrapper
279+
280+
281+
def dataclass(cls=None, /, *, init=True, repr=True, eq=True, order=False,
282+
unsafe_hash=False, frozen=False, match_args=True,
283+
kw_only=False, slots=False, weakref_slot=False):
284+
285+
ret = orig_dataclass(cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen,
286+
match_args=match_args, kw_only=kw_only, slots=slots, weakref_slot=weakref_slot)
287+
288+
for name, field in ret.__dataclass_fields__.items():
289+
setattr(cls, name, property(fget=_getter(name, field.type), fset=_setter(name, field.type)))
290+
291+
ret.__init__ = __dc_init(ret.__init__)
292+
ret.__no_copy__ = True
293+
294+
return ret

pydantic_bind/cpp_generator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ def generate_class(model_class: ModelMetaclass, indent_size: int = 0, max_width:
244244
pydantic_bases = ", " + ", ".join(base.__name__ for base in base_init.keys()) if base_init else ""
245245
pydantic_init = "\n".join(args_wrapper.wrap(f"{', '.join(types)}>(), {', '.join(kwargs)}"))
246246
pydantic_def = f"""{indent}py::class_<{cls_name}{pydantic_bases}>(m, "{cls_name}")
247+
{indent}.def(py::init<>())
247248
{indent}.def(py::init<{pydantic_init})
248249
{indent}{newline_indent.join(pydantic_attrs)};"""
249250

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# ToDo: Add versioneer
44

55
setup(name="pydantic_bind",
6-
version="1.2.7",
6+
version="1.2.8",
77
description="C++/pybind generation from Pydantic classes",
88
author="Nick Young",
99
license=r"https://www.apache.org/licenses/LICENSE-2.0",

0 commit comments

Comments
 (0)