@@ -798,6 +798,21 @@ def tree_unflatten(
798798 array_field_values : list [DcOrArray ],
799799 ) -> _DcT :
800800 """`jax.tree_utils` support."""
801+ return cls ._tree_unflatten (
802+ metadata = metadata ,
803+ array_field_values = array_field_values ,
804+ constructor = cls ,
805+ )
806+
807+ @classmethod
808+ def _tree_unflatten (
809+ cls : Type [_DcT ],
810+ * ,
811+ metadata : _TreeMetadata ,
812+ array_field_values : list [DcOrArray ],
813+ constructor : Callable [..., _DcT ],
814+ ) -> _DcT :
815+ """Initialize a model after serialization."""
801816 array_field_kwargs = dict (
802817 zip (
803818 metadata .array_field_names ,
@@ -813,7 +828,7 @@ def tree_unflatten(
813828 else :
814829 non_init_fields [k ] = v
815830
816- self = cls (** array_field_kwargs , ** init_fields )
831+ self = constructor (** array_field_kwargs , ** init_fields )
817832 # Currently it's not clear how to handle non-init fields so raise an error
818833 if non_init_fields :
819834 if set (non_init_fields ) - self .__dca_non_init_fields__ :
@@ -844,6 +859,23 @@ def __tf_unflatten__(
844859 ) -> _DcT :
845860 return cls .tree_unflatten (metadata , components )
846861
862+ def __getstate__ (self ) -> tuple [_TreeMetadata , tuple [DcOrArray , ...]]:
863+ components , metadata = self ._tree_flatten ()
864+ return metadata , components
865+
866+ def __setstate__ (self , state : tuple [_TreeMetadata , list [DcOrArray ]]) -> None :
867+
868+ def constructor (** kwargs ):
869+ self .__init__ (** kwargs )
870+ return self
871+
872+ metadata , components = state
873+ type (self )._tree_unflatten (
874+ metadata = metadata ,
875+ array_field_values = components ,
876+ constructor = constructor ,
877+ )
878+
847879 def _setattr (self , name : str , value : Any ) -> None :
848880 """Like setattr, but support `frozen` dataclasses."""
849881 object .__setattr__ (self , name , value )
@@ -1023,8 +1055,8 @@ class _ArrayFieldMetadata:
10231055 Attributes:
10241056 inner_shape_non_static: Inner shape. Can contain non-static dims (e.g.
10251057 `(None, 3)`)
1026- dtype: Type of the array. Can be `enp.dtypes.DType` or
1027- `dca.DataclassArray` for nested arrays.
1058+ dtype: Type of the array. Can be `enp.dtypes.DType` or `dca.DataclassArray`
1059+ for nested arrays.
10281060 """
10291061
10301062 inner_shape_non_static : DynamicShape
0 commit comments