|
2 | 2 | import abc |
3 | 3 | import collections |
4 | 4 | import base64 |
| 5 | +import dataclasses |
5 | 6 | import functools |
6 | 7 | import io |
7 | 8 | import itertools |
@@ -2770,6 +2771,86 @@ def func_with_globals(): |
2770 | 2771 | "Expected a single deterministic payload, got %d/5" % len(vals) |
2771 | 2772 | ) |
2772 | 2773 |
|
| 2774 | + def test_dataclass_fields_are_preserved(self): |
| 2775 | + |
| 2776 | + @dataclasses.dataclass |
| 2777 | + class SampleDataclass: |
| 2778 | + x: int |
| 2779 | + y: dataclasses.InitVar[int] |
| 2780 | + z: typing.ClassVar[int] |
| 2781 | + |
| 2782 | + PickledSampleDataclass = pickle_depickle( |
| 2783 | + SampleDataclass, protocol=self.protocol |
| 2784 | + ) |
| 2785 | + |
| 2786 | + found_fields = list(PickledSampleDataclass.__dataclass_fields__.values()) |
| 2787 | + assert set(f.name for f in found_fields) == { |
| 2788 | + "x", "y", "z" |
| 2789 | + } |
| 2790 | + |
| 2791 | + expected_ftypes = { |
| 2792 | + "x": dataclasses._FIELD, |
| 2793 | + "y": dataclasses._FIELD_INITVAR, |
| 2794 | + "z": dataclasses._FIELD_CLASSVAR, |
| 2795 | + } |
| 2796 | + |
| 2797 | + for f in found_fields: |
| 2798 | + assert f._field_type is expected_ftypes[f.name] |
| 2799 | + |
| 2800 | + def test_interactively_defined_dataclass_with_initvar_and_classvar(self): |
| 2801 | + code = """if __name__ == "__main__": |
| 2802 | + import dataclasses |
| 2803 | + from testutils import subprocess_worker |
| 2804 | + import typing |
| 2805 | +
|
| 2806 | + with subprocess_worker(protocol={protocol}) as w: |
| 2807 | +
|
| 2808 | + @dataclasses.dataclass |
| 2809 | + class SampleDataclass: |
| 2810 | + x: int |
| 2811 | + y: dataclasses.InitVar[int] = None |
| 2812 | + z: typing.ClassVar[int] = 42 |
| 2813 | +
|
| 2814 | + def __post_init__(self, y=0): |
| 2815 | + self.x += y |
| 2816 | +
|
| 2817 | + def large_enough(self): |
| 2818 | + return self.x > self.z |
| 2819 | +
|
| 2820 | + value = SampleDataclass(2, y=2) |
| 2821 | +
|
| 2822 | + def check_dataclass_instance(value): |
| 2823 | + assert isinstance(value, SampleDataclass) |
| 2824 | + assert value.x == 4 |
| 2825 | + assert value.z == 42 |
| 2826 | + expected_dict = dict(x=4) |
| 2827 | + assert dataclasses.asdict(value) == expected_dict |
| 2828 | + assert not value.large_enough() |
| 2829 | + try: |
| 2830 | + SampleDataclass.z = 0 |
| 2831 | + assert value.z == 0 |
| 2832 | + assert value.large_enough() |
| 2833 | + finally: |
| 2834 | + SampleDataclass.z = 42 |
| 2835 | + return "ok" |
| 2836 | +
|
| 2837 | + assert check_dataclass_instance(value) == "ok" |
| 2838 | +
|
| 2839 | + # Check that this instance of an interactively defined dataclass |
| 2840 | + # behavesconsistently in a remote worker process: |
| 2841 | + assert w.run(check_dataclass_instance, value) == "ok" |
| 2842 | +
|
| 2843 | + # Check class provenance tracking is not impacted by the |
| 2844 | + # @dataclass decorator: |
| 2845 | + def echo(*args): |
| 2846 | + return args |
| 2847 | +
|
| 2848 | + cloned_value, cloned_type = w.run(echo, value, SampleDataclass) |
| 2849 | + assert cloned_type is SampleDataclass |
| 2850 | + assert isinstance(cloned_value, SampleDataclass) |
| 2851 | + """.format(protocol=self.protocol) |
| 2852 | + assert_run_python_script(code) |
| 2853 | + |
2773 | 2854 |
|
2774 | 2855 | class Protocol2CloudPickleTest(CloudPickleTest): |
2775 | 2856 |
|
|
0 commit comments