Skip to content

Commit ab32ba5

Browse files
rmorshea and ogrisel authored
handle dataclass field type sentinels (#513)
Co-authored-by: Olivier Grisel <[email protected]>
1 parent 0f330b6 commit ab32ba5

File tree

3 files changed

+108
-2
lines changed

3 files changed

+108
-2
lines changed

CHANGES.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
2.3.0 (development)
22
===================
33

4-
TODO
5-
4+
- Fix pickling of dataclasses and their instances.
5+
([issue #386](https://github.com/cloudpipe/cloudpickle/issues/386),
6+
[PR #513](https://github.com/cloudpipe/cloudpickle/pull/513))
67

78
2.2.1
89
=====

cloudpickle/cloudpickle_fast.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import _collections_abc
1414
import abc
1515
import copyreg
16+
import dataclasses
1617
import io
1718
import itertools
1819
import logging
@@ -482,6 +483,10 @@ def _odict_items_reduce(obj):
482483
return _make_dict_items, (dict(obj), True)
483484

484485

486+
def _dataclass_field_base_reduce(obj):
487+
return _get_dataclass_field_type_sentinel, (obj.name,)
488+
489+
485490
# COLLECTIONS OF OBJECTS STATE SETTERS
486491
# ------------------------------------
487492
# state setters are called at unpickling time, once the object is created and
@@ -537,6 +542,24 @@ def _class_setstate(obj, state):
537542
return obj
538543

539544

545+
# COLLECTION OF DATACLASS UTILITIES
546+
# ---------------------------------
547+
# There are some internal sentinel values whose identity must be preserved when
548+
# unpickling dataclass fields. Each sentinel value has a unique name that we can
549+
# use to retrieve its identity at unpickling time.
550+
551+
552+
_DATACLASSE_FIELD_TYPE_SENTINELS = {
553+
dataclasses._FIELD.name: dataclasses._FIELD,
554+
dataclasses._FIELD_CLASSVAR.name: dataclasses._FIELD_CLASSVAR,
555+
dataclasses._FIELD_INITVAR.name: dataclasses._FIELD_INITVAR,
556+
}
557+
558+
559+
def _get_dataclass_field_type_sentinel(name):
560+
return _DATACLASSE_FIELD_TYPE_SENTINELS[name]
561+
562+
540563
class CloudPickler(Pickler):
541564
# set of reducers defined and used by cloudpickle (private)
542565
_dispatch_table = {}
@@ -565,6 +588,7 @@ class CloudPickler(Pickler):
565588
_dispatch_table[abc.abstractclassmethod] = _classmethod_reduce
566589
_dispatch_table[abc.abstractstaticmethod] = _classmethod_reduce
567590
_dispatch_table[abc.abstractproperty] = _property_reduce
591+
_dispatch_table[dataclasses._FIELD_BASE] = _dataclass_field_base_reduce
568592

569593
dispatch_table = ChainMap(_dispatch_table, copyreg.dispatch_table)
570594

tests/cloudpickle_test.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import abc
33
import collections
44
import base64
5+
import dataclasses
56
import functools
67
import io
78
import itertools
@@ -2770,6 +2771,86 @@ def func_with_globals():
27702771
"Expected a single deterministic payload, got %d/5" % len(vals)
27712772
)
27722773

2774+
def test_dataclass_fields_are_preserved(self):
2775+
2776+
@dataclasses.dataclass
2777+
class SampleDataclass:
2778+
x: int
2779+
y: dataclasses.InitVar[int]
2780+
z: typing.ClassVar[int]
2781+
2782+
PickledSampleDataclass = pickle_depickle(
2783+
SampleDataclass, protocol=self.protocol
2784+
)
2785+
2786+
found_fields = list(PickledSampleDataclass.__dataclass_fields__.values())
2787+
assert set(f.name for f in found_fields) == {
2788+
"x", "y", "z"
2789+
}
2790+
2791+
expected_ftypes = {
2792+
"x": dataclasses._FIELD,
2793+
"y": dataclasses._FIELD_INITVAR,
2794+
"z": dataclasses._FIELD_CLASSVAR,
2795+
}
2796+
2797+
for f in found_fields:
2798+
assert f._field_type is expected_ftypes[f.name]
2799+
2800+
def test_interactively_defined_dataclass_with_initvar_and_classvar(self):
2801+
code = """if __name__ == "__main__":
2802+
import dataclasses
2803+
from testutils import subprocess_worker
2804+
import typing
2805+
2806+
with subprocess_worker(protocol={protocol}) as w:
2807+
2808+
@dataclasses.dataclass
2809+
class SampleDataclass:
2810+
x: int
2811+
y: dataclasses.InitVar[int] = None
2812+
z: typing.ClassVar[int] = 42
2813+
2814+
def __post_init__(self, y=0):
2815+
self.x += y
2816+
2817+
def large_enough(self):
2818+
return self.x > self.z
2819+
2820+
value = SampleDataclass(2, y=2)
2821+
2822+
def check_dataclass_instance(value):
2823+
assert isinstance(value, SampleDataclass)
2824+
assert value.x == 4
2825+
assert value.z == 42
2826+
expected_dict = dict(x=4)
2827+
assert dataclasses.asdict(value) == expected_dict
2828+
assert not value.large_enough()
2829+
try:
2830+
SampleDataclass.z = 0
2831+
assert value.z == 0
2832+
assert value.large_enough()
2833+
finally:
2834+
SampleDataclass.z = 42
2835+
return "ok"
2836+
2837+
assert check_dataclass_instance(value) == "ok"
2838+
2839+
# Check that this instance of an interactively defined dataclass
2840+
# behaves consistently in a remote worker process:
2841+
assert w.run(check_dataclass_instance, value) == "ok"
2842+
2843+
# Check class provenance tracking is not impacted by the
2844+
# @dataclass decorator:
2845+
def echo(*args):
2846+
return args
2847+
2848+
cloned_value, cloned_type = w.run(echo, value, SampleDataclass)
2849+
assert cloned_type is SampleDataclass
2850+
assert isinstance(cloned_value, SampleDataclass)
2851+
""".format(protocol=self.protocol)
2852+
assert_run_python_script(code)
2853+
27732854

27742855
class Protocol2CloudPickleTest(CloudPickleTest):
27752856

0 commit comments

Comments
 (0)