22from enum import Enum , EnumType
33from functools import cache , wraps
44from importlib import import_module
5- from pydantic import BaseModel as BaseModel , ConfigDict , computed_field
5+ from pydantic import BaseModel as PydanticBaseModel , ConfigDict , computed_field
66from pydantic .fields import ComputedFieldInfo , FieldInfo
77from pydantic .json_schema import GenerateJsonSchema
88from pydantic ._internal ._config import ConfigWrapper
1313from pydantic_core import PydanticUndefined
1414import sys
1515from types import UnionType
16- from typing import Any , Dict , List , Type , Optional , Union , cast , get_args , get_origin
16+ from typing import Any , Dict , List , Optional , Sequence , Type , Union , cast , get_args , get_origin
1717
1818
1919class UnconvertableValue (Exception ):
2020 pass
2121
2222
23- class __IBaseModelNoCopy :
24- pass
25-
26-
27- def field_info_iter (model_class : ModelMetaclass ):
23+ def field_info_iter (model_class ):
2824 if is_dataclass (model_class ):
2925 for field_name , field in model_class .__dataclass_fields__ .items ():
3026 yield field_name , field .type , field .default
31- elif issubclass (model_class , BaseModelNoCopy ):
32- for field_name , field in model_class .__pydantic_decorators__ .computed_fields .items ():
33- yield field_name , field .info .return_type , field .info .default
34- else :
35- for field_name , field in model_class .model_fields .items ():
36- yield field_name , field .annotation , field .default
27+ elif issubclass (model_class , PydanticBaseModel ):
28+ if hasattr (model_class , "__has_pybind_impl__" ):
29+ for field_name , field in model_class .__pydantic_decorators__ .computed_fields .items ():
30+ yield field_name , field .info .return_type , field .info .default
31+ else :
32+ for field_name , field in model_class .model_fields .items ():
33+ yield field_name , field .annotation , field .default
3734
3835
3936@cache
@@ -69,41 +66,45 @@ def get_pybind_value(obj):
6966def _get_pybind_value (obj , default_to_self : bool = True ):
7067 if isinstance (obj , Enum ):
7168 return get_pybind_type (type (obj )).__entries [obj .name ][0 ]
72- elif is_dataclass (obj ):
73- return get_pybind_type ( type (obj ))( ** { name : _get_pybind_value ( getattr ( obj , name ) )
74- for name in obj . __dataclass_fields__ . keys ()} )
75- elif isinstance ( obj , __IBaseModelNoCopy ):
76- return get_pybind_type ( type ( obj ))( ** { name : _get_pybind_value ( getattr ( obj , name ))
77- for name in obj . model_computed_fields . keys ()})
78- elif isinstance ( obj , BaseModel ):
79- return get_pybind_type ( type ( obj ))( ** { name : _get_pybind_value ( getattr ( obj , name ))
80- for name in obj . model_fields . keys () })
69+ elif is_dataclass (obj ) or isinstance ( obj , PydanticBaseModel ) :
70+ typ = type (obj )
71+ pybind_type = get_pybind_type ( typ )
72+ name_iter = ( name for name , _ , _ in field_info_iter ( typ ))
73+
74+ if hasattr ( typ , "__has_pybind_impl__" ):
75+ return pybind_type ( ** { name : getattr ( obj . pybind_impl , name ) for name in name_iter })
76+ else :
77+ return pybind_type ( ** { name : _get_pybind_value ( getattr ( obj , name )) for name in name_iter })
8178 elif default_to_self :
8279 return obj
8380 else :
84- raise UnconvertableValue ("Only dataclasses and pydantic classes supported" )
81+ raise UnconvertableValue ("Only builtins, dataclasses and pydantic classes supported" )
8582
8683
8784def from_pybind_value (value , typ : Type ):
8885 origin = get_origin (typ )
8986 args = get_args (typ )
90- is_dc = is_dataclass (typ )
9187
9288 if origin is Optional :
9389 typ = args [0 ]
94- elif origin in (Union , UnionType ):
90+ args = get_args (typ )
91+
92+ if origin in (Union , UnionType ):
9593 typ = next (a for a in args if a .__name__ == type (value ).__name__ )
9694
95+ is_dc_or_pydantic = is_dataclass (typ ) or issubclass (typ , PydanticBaseModel )
96+
9797 if issubclass (typ , Enum ):
9898 return typ [value .name ]
99- elif issubclass (typ , __IBaseModelNoCopy ) or (is_dc and hasattr (typ , "__no_copy__" )):
100- return typ (__pybind_impl__ = value )
101- elif is_dc or issubclass (typ , BaseModel ):
102- # This is quite inefficient
103- kwargs = {}
104- for field_name , field_type , _ in field_info_iter (typ ):
105- kwargs [field_name ] = from_pybind_value (getattr (value , field_name ), field_type )
106- return typ (** kwargs )
99+ elif is_dc_or_pydantic :
100+ if hasattr (typ , "__has_pybind_impl__" ):
101+ return typ (__pybind_impl__ = value )
102+ else :
103+ # This is quite inefficient
104+ kwargs = {}
105+ for field_name , field_type , _ in field_info_iter (typ ):
106+ kwargs [field_name ] = from_pybind_value (getattr (value , field_name ), field_type )
107+ return typ (** kwargs )
107108 else :
108109 return value
109110
@@ -194,6 +195,7 @@ def __new__(
194195 namespace [name ] = prop
195196
196197 cls = cast (ModelMetaclass , super ().__new__ (mcs , cls_name , bases , namespace , ** kwargs ))
198+ cls .__has_pybind_impl__ = True
197199 cls .__pydantic_decorators__ .__annotations__ ["computed_fields" ] = dict [str , Decorator [PropertyFieldInfo ]]
198200 cls .__signature__ = ClassAttribute (
199201 '__signature__' , generate_model_signature (cls .__init__ , field_infos , config_wrapper )
@@ -226,7 +228,13 @@ def json_schema_extra(schema: Dict[str, Any], model_class: ModelMetaclassNoCopy)
226228 properties [alias ] = field_schema
227229
228230
229- class BaseModelNoCopy (BaseModel , __IBaseModelNoCopy , metaclass = ModelMetaclassNoCopy ):
231+ def _from_msg_pack (cls , data : Sequence [int ]):
232+ typ = get_pybind_type (cls )
233+ pybind_impl , _error_code = typ .from_msg_pack (data )
234+ return cls (__pybind_impl__ = pybind_impl )
235+
236+
237+ class BaseModel (PydanticBaseModel , metaclass = ModelMetaclassNoCopy ):
230238 model_config = ConfigDict (json_schema_extra = json_schema_extra )
231239
232240 @property
@@ -237,9 +245,14 @@ def model_computed_fields(self) -> dict[str, PropertyFieldInfo]:
237245 def pybind_impl (self ):
238246 return self .__pybind_impl
239247
240- def __init__ (self , ** kwargs ):
241- super ().__init__ ()
248+ def to_msg_pack (self ):
249+ return self .__pybind_impl .to_msg_pack ()
250+
251+ @classmethod
252+ def from_msg_pack (cls , data : Sequence [int ]):
253+ return _from_msg_pack (cls , data )
242254
255+ def __init__ (self , ** kwargs ):
243256 __pybind_impl__ = kwargs .pop ("__pybind_impl__" , None )
244257 if __pybind_impl__ :
245258 self .__pybind_impl = __pybind_impl__
@@ -266,19 +279,45 @@ def __init__(self, **kwargs):
266279 raise RuntimeError (f"Missing required fields: { missing_required } " )
267280
268281 pybind_type = get_pybind_type (type (self ))
269- self .__pybind_impl = pybind_type (** kwargs )
282+ object .__setattr__ (self , "_BaseModel__pybind_impl" , pybind_type (** kwargs ))
283+
284+ super ().__init__ ()
285+
286+ @property
287+ def __dict__ (self ):
288+ return {name : from_pybind_value (getattr (self , name ), typ ) for name , typ , _ in field_info_iter (type (self ))}
289+
290+ @__dict__ .setter
291+ def __dict__ (self , value : dict ):
292+ try :
293+ object .__getattribute__ (self , "_BaseModel__pybind_impl" )
294+ for name , value in value .items ():
295+ object .__setattr__ (self , name , value )
296+ except AttributeError :
297+ self .__init__ (** value )
270298
271299
272300def __dataclass_init (init ):
273301 @wraps (init )
274302 def wrapper (self , * args , __pybind_impl__ = None , ** kwargs ):
275- self .pybind_impl = __pybind_impl__ or get_pybind_type (type (self ))()
276- init (self , * args , ** kwargs )
277- arse = True
303+ if __pybind_impl__ :
304+ self .__pybind_impl = __pybind_impl__
305+ else :
306+ self .__pybind_impl = get_pybind_type (type (self ))()
307+ init (self , * args , ** kwargs )
278308
279309 return wrapper
280310
281311
312+ def to_msg_pack (self ):
313+ return self .pybind_impl .to_msg_pack ()
314+
315+
316+ @classmethod
317+ def from_msg_pack (cls , data : Sequence [int ]):
318+ return _from_msg_pack (cls , data )
319+
320+
282321def dataclass (cls = None , / , * , init = True , repr = True , eq = True , order = False ,
283322 unsafe_hash = False , frozen = False , match_args = True ,
284323 kw_only = False , slots = False , weakref_slot = False ):
@@ -290,6 +329,10 @@ def dataclass(cls=None, /, *, init=True, repr=True, eq=True, order=False,
290329 setattr (cls , name , property (fget = _getter (name , field .type ), fset = _setter (name , field .type )))
291330
292331 ret .__init__ = __dataclass_init (ret .__init__ )
293- ret .__no_copy__ = True
332+ ret .__has_pybind_impl__ = True
333+
334+ ret .to_msg_pack = to_msg_pack
335+ ret .from_msg_pack = from_msg_pack
336+ ret .pybind_impl = property (fget = lambda self : self .__pybind_impl )
294337
295338 return ret
0 commit comments