1- from dataclasses import is_dataclass
1+ from dataclasses import dataclass as orig_dataclass , is_dataclass
22from enum import Enum , EnumType
3- from functools import cache
3+ from functools import cache , wraps
44from importlib import import_module
55from pydantic import BaseModel as BaseModel , ConfigDict , computed_field
66from pydantic .fields import ComputedFieldInfo , FieldInfo
@@ -84,9 +84,10 @@ def _get_pybind_value(obj, default_to_self: bool = True):
8484 raise UnconvertableValue ("Only dataclasses and pydantic classes supported" )
8585
8686
87- def from_pybind_value (value , typ :Type ):
87+ def from_pybind_value (value , typ : Type ):
8888 origin = get_origin (typ )
8989 args = get_args (typ )
90+ is_dc = is_dataclass (typ )
9091
9192 if origin is Optional :
9293 typ = args [0 ]
@@ -95,9 +96,9 @@ def from_pybind_value(value, typ:Type):
9596
9697 if issubclass (typ , Enum ):
9798 return typ [value .name ]
98- elif issubclass (typ , __IBaseModelNoCopy ):
99+ elif issubclass (typ , __IBaseModelNoCopy ) or ( is_dc and hasattr ( typ , "__no_copy__" )) :
99100 return typ (__pybind_impl__ = value )
100- elif is_dataclass ( typ ) or issubclass (typ , BaseModel ):
101+ elif is_dc or issubclass (typ , BaseModel ):
101102 # This is quite inefficient
102103 kwargs = {}
103104 for field_name , field_type , _ in field_info_iter (typ ):
@@ -266,3 +267,28 @@ def __init__(self, **kwargs):
266267
267268 pybind_type = get_pybind_type (type (self ))
268269 self .__pybind_impl = pybind_type (** kwargs )
270+
271+
272+ def __dc_init (init ):
273+ @wraps (init )
274+ def wrapper (self , * args , __pybind_impl__ = None , ** kwargs ):
275+ self .__pybind_impl = __pybind_impl__ or get_pybind_type (type (self ))()
276+ return init (self , * args , ** kwargs )
277+
278+ return wrapper
279+
280+
281+ def dataclass (cls = None , / , * , init = True , repr = True , eq = True , order = False ,
282+ unsafe_hash = False , frozen = False , match_args = True ,
283+ kw_only = False , slots = False , weakref_slot = False ):
284+
285+ ret = orig_dataclass (cls , init = init , repr = repr , eq = eq , order = order , unsafe_hash = unsafe_hash , frozen = frozen ,
286+ match_args = match_args , kw_only = kw_only , slots = slots , weakref_slot = weakref_slot )
287+
288+ for name , field in ret .__dataclass_fields__ .items ():
289+ setattr (cls , name , property (fget = _getter (name , field .type ), fset = _setter (name , field .type )))
290+
291+ ret .__init__ = __dc_init (ret .__init__ )
292+ ret .__no_copy__ = True
293+
294+ return ret
0 commit comments