Skip to content

Commit 7b39334

Browse files
Jan HosangThe dataclass_array Authors
authored andcommitted
Support pickle protocol for DataclassArray.
PiperOrigin-RevId: 675147723
1 parent ebe7df4 commit 7b39334

File tree

3 files changed

+50
-3
lines changed

3 files changed

+50
-3
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
2323

2424
## [Unreleased]
2525

26+
* Add pickle support to `DataclassArray`.
27+
2628
## [1.5.2] - 2024-03-19
2729

2830
* Drop Python 3.10 support

dataclass_array/array_dataclass.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -798,6 +798,21 @@ def tree_unflatten(
798798
array_field_values: list[DcOrArray],
799799
) -> _DcT:
800800
"""`jax.tree_utils` support."""
801+
return cls._tree_unflatten(
802+
metadata=metadata,
803+
array_field_values=array_field_values,
804+
constructor=cls,
805+
)
806+
807+
@classmethod
808+
def _tree_unflatten(
809+
cls: Type[_DcT],
810+
*,
811+
metadata: _TreeMetadata,
812+
array_field_values: list[DcOrArray],
813+
constructor: Callable[..., _DcT],
814+
) -> _DcT:
815+
"""Initialize a model after serialization."""
801816
array_field_kwargs = dict(
802817
zip(
803818
metadata.array_field_names,
@@ -813,7 +828,7 @@ def tree_unflatten(
813828
else:
814829
non_init_fields[k] = v
815830

816-
self = cls(**array_field_kwargs, **init_fields)
831+
self = constructor(**array_field_kwargs, **init_fields)
817832
# Currently it's not clear how to handle non-init fields so raise an error
818833
if non_init_fields:
819834
if set(non_init_fields) - self.__dca_non_init_fields__:
@@ -844,6 +859,23 @@ def __tf_unflatten__(
844859
) -> _DcT:
845860
return cls.tree_unflatten(metadata, components)
846861

862+
def __getstate__(self) -> tuple[_TreeMetadata, tuple[DcOrArray, ...]]:
863+
components, metadata = self._tree_flatten()
864+
return metadata, components
865+
866+
def __setstate__(self, state: tuple[_TreeMetadata, list[DcOrArray]]) -> None:
867+
868+
def constructor(**kwargs):
869+
self.__init__(**kwargs)
870+
return self
871+
872+
metadata, components = state
873+
type(self)._tree_unflatten(
874+
metadata=metadata,
875+
array_field_values=components,
876+
constructor=constructor,
877+
)
878+
847879
def _setattr(self, name: str, value: Any) -> None:
848880
"""Like setattr, but support `frozen` dataclasses."""
849881
object.__setattr__(self, name, value)
@@ -1023,8 +1055,8 @@ class _ArrayFieldMetadata:
10231055
Attributes:
10241056
inner_shape_non_static: Inner shape. Can contain non-static dims (e.g.
10251057
`(None, 3)`)
1026-
dtype: Type of the array. Can be `enp.dtypes.DType` or
1027-
`dca.DataclassArray` for nested arrays.
1058+
dtype: Type of the array. Can be `enp.dtypes.DType` or `dca.DataclassArray`
1059+
for nested arrays.
10281060
"""
10291061

10301062
inner_shape_non_static: DynamicShape

dataclass_array/array_dataclass_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from __future__ import annotations
1818

19+
import pickle
1920
from typing import Any
2021

2122
import dataclass_array as dca
@@ -769,3 +770,15 @@ def test_class_getitem():
769770
assert Point[''] != Point
770771
assert Point[''] != Point['h w']
771772
assert Point[''] != Isometrie['']
773+
774+
775+
@enp.testing.parametrize_xnp()
776+
@pytest.mark.parametrize('batch_shape', [(), (1, 3)])
777+
def test_dataclass_pickle_unpickle(xnp: enp.NpModule, batch_shape: Shape):
778+
expected = Point(
779+
x=xnp.zeros(batch_shape, dtype=xnp.float32),
780+
y=xnp.zeros(batch_shape, dtype=xnp.float32),
781+
)
782+
buffer = pickle.dumps(expected)
783+
actual = pickle.loads(buffer)
784+
dca.testing.assert_array_equal(actual, expected)

0 commit comments

Comments
 (0)