diff --git a/cloudpickle/cloudpickle.py b/cloudpickle/cloudpickle.py index 347b3869..6812ad61 100644 --- a/cloudpickle/cloudpickle.py +++ b/cloudpickle/cloudpickle.py @@ -55,7 +55,6 @@ import warnings from .compat import pickle -from collections import OrderedDict from typing import Generic, Union, Tuple, Callable from pickle import _getattribute from importlib._bootstrap import _find_spec @@ -952,22 +951,31 @@ def _get_bases(typ): return getattr(typ, bases_attr) -def _make_dict_keys(obj, is_ordered=False): - if is_ordered: - return OrderedDict.fromkeys(obj).keys() +def _make_keys_view(obj, typ): + t = typ or dict + if hasattr(t, "fromkeys"): + o = t.fromkeys(obj) else: - return dict.fromkeys(obj).keys() + o = dict.fromkeys(obj) + try: + o = t(o) + except Exception: + pass + return o.keys() -def _make_dict_values(obj, is_ordered=False): - if is_ordered: - return OrderedDict((i, _) for i, _ in enumerate(obj)).values() - else: - return {i: _ for i, _ in enumerate(obj)}.values() +def _make_values_view(obj, typ): + t = typ or dict + o = t({i: _ for i, _ in enumerate(obj)}) + return o.values() -def _make_dict_items(obj, is_ordered=False): - if is_ordered: - return OrderedDict(obj).items() - else: - return obj.items() +def _make_items_view(obj, typ): + t = typ or dict + try: + o = t(obj) + except Exception: + o = dict(obj) + if hasattr(o, "items"): + return o.items() + return o diff --git a/cloudpickle/cloudpickle_fast.py b/cloudpickle/cloudpickle_fast.py index 6db059eb..2acdad20 100644 --- a/cloudpickle/cloudpickle_fast.py +++ b/cloudpickle/cloudpickle_fast.py @@ -10,7 +10,6 @@ guards present in cloudpickle.py that were written to handle PyPy specificities are not present in cloudpickle_fast.py """ -import _collections_abc import abc import copyreg import io @@ -23,7 +22,8 @@ import typing from enum import Enum -from collections import ChainMap, OrderedDict +from collections import ChainMap, OrderedDict, UserDict, namedtuple +from typing import Iterable from .compat import pickle, Pickler from .cloudpickle import ( @@ -35,7 +35,7 @@ _is_parametrized_type_hint, PYPY, cell_set, parametrized_type_hint_getinitargs, _create_parametrized_type_hint, builtin_code_type, - _make_dict_keys, _make_dict_values, _make_dict_items, + _make_keys_view, _make_values_view, _make_items_view, ) @@ -418,41 +418,66 @@ def _class_reduce(obj): return _dynamic_class_reduce(obj) return NotImplemented +# DICT VIEWS TYPES -def _dict_keys_reduce(obj): - # Safer not to ship the full dict as sending the rest might - # be unintended and could potentially cause leaking of - # sensitive information - return _make_dict_keys, (list(obj), ) +ViewInfo = namedtuple("ViewInfo", ["view", "packer", "maker", "object_type"]) # noqa -def _dict_values_reduce(obj): - # Safer not to ship the full dict as sending the rest might - # be unintended and could potentially cause leaking of - # sensitive information - return _make_dict_values, (list(obj), ) +_VIEW_ATTRS_INFO = [ + ViewInfo("keys", list, _make_keys_view, None), + ViewInfo("values", list, _make_values_view, None), + ViewInfo("items", dict, _make_items_view, None), +] -def _dict_items_reduce(obj): - return _make_dict_items, (dict(obj), ) +_VIEWS_TYPES_TABLE = {} -def _odict_keys_reduce(obj): - # Safer not to ship the full dict as sending the rest might - # be unintended and could potentially cause leaking of - # sensitive information - return _make_dict_keys, (list(obj), True) +def register_views_types(types, attr_info=None, overwrite=False): + """Register views types, returns a copy.""" + if not isinstance(types, Iterable): + types = [types] + if attr_info is None: + attr_info = _VIEW_ATTRS_INFO.copy() + elif isinstance(attr_info, ViewInfo): + attr_info = [attr_info] + for typ in types: + for info in attr_info: + if isinstance(typ, type): + object_type = typ + obj_instance = object_type() + else: + object_type = type(typ) + obj_instance = typ + a = getattr(obj_instance, info.view, None) + if callable(a): + view_obj = a() + else: + view_obj = a + if a is None: + continue + view_type = type(view_obj) + if view_type not in _VIEWS_TYPES_TABLE or overwrite: + _VIEWS_TYPES_TABLE[view_type] = ViewInfo( + info.view, + info.packer, + info.maker, + object_type, + ) + return _VIEWS_TYPES_TABLE -def _odict_values_reduce(obj): - # Safer not to ship the full dict as sending the rest might - # be unintended and could potentially cause leaking of - # sensitive information - return _make_dict_values, (list(obj), True) +_VIEWABLE_TYPES = [dict, OrderedDict, UserDict] +register_views_types(_VIEWABLE_TYPES) -def _odict_items_reduce(obj): - return _make_dict_items, (dict(obj), True) +def _views_reducer(obj): + typ = type(obj) + info = _VIEWS_TYPES_TABLE[typ] + return info.maker, (info.packer(obj), info.object_type) + + +_VIEWS_DISPATCH_TABLE = {k: _views_reducer for k in _VIEWS_TYPES_TABLE.keys()} # COLLECTIONS OF OBJECTS STATE SETTERS @@ -528,13 +553,7 @@ class CloudPickler(Pickler): _dispatch_table[types.MappingProxyType] = _mappingproxy_reduce _dispatch_table[weakref.WeakSet] = _weakset_reduce _dispatch_table[typing.TypeVar] = _typevar_reduce - _dispatch_table[_collections_abc.dict_keys] = _dict_keys_reduce - _dispatch_table[_collections_abc.dict_values] = _dict_values_reduce - _dispatch_table[_collections_abc.dict_items] = _dict_items_reduce - _dispatch_table[type(OrderedDict().keys())] = _odict_keys_reduce - _dispatch_table[type(OrderedDict().values())] = _odict_values_reduce - _dispatch_table[type(OrderedDict().items())] = _odict_items_reduce - + _dispatch_table.update(_VIEWS_DISPATCH_TABLE) dispatch_table = ChainMap(_dispatch_table, copyreg.dispatch_table) diff --git a/tests/cloudpickle_test.py b/tests/cloudpickle_test.py index d2acfb71..c371d10e 100644 --- a/tests/cloudpickle_test.py +++ b/tests/cloudpickle_test.py @@ -1,6 +1,5 @@ from __future__ import division -import _collections_abc import abc import collections import base64 @@ -53,14 +52,13 @@ from cloudpickle.cloudpickle import _make_empty_cell, cell_set from cloudpickle.cloudpickle import _extract_class_dict, _whichmodule from cloudpickle.cloudpickle import _lookup_module_and_qualname +from cloudpickle.cloudpickle_fast import _VIEWS_TYPES_TABLE from .testutils import subprocess_pickle_echo from .testutils import subprocess_pickle_string from .testutils import assert_run_python_script from .testutils import subprocess_worker -from _cloudpickle_testpkg import relative_imports_factory - _TEST_GLOBAL_VARIABLE = "default_value" _TEST_GLOBAL_VARIABLE2 = "another_value" @@ -224,40 +222,53 @@ def test_memoryview(self): buffer_obj.tobytes()) def test_dict_keys(self): - keys = {"a": 1, "b": 2}.keys() - results = pickle_depickle(keys) - self.assertEqual(results, keys) - assert isinstance(results, _collections_abc.dict_keys) + data = {"a": 1, "b": 2}.keys() + results = pickle_depickle(data) + self.assertEqual(results, data) + assert isinstance(results, type(data)) def test_dict_values(self): - values = {"a": 1, "b": 2}.values() - results = pickle_depickle(values) - self.assertEqual(sorted(results), sorted(values)) - assert isinstance(results, _collections_abc.dict_values) + data = {"a": 1, "b": 2}.values() + results = pickle_depickle(data) + self.assertEqual(sorted(results), sorted(data)) + assert isinstance(results, type(data)) def test_dict_items(self): - items = {"a": 1, "b": 2}.items() - results = pickle_depickle(items) - self.assertEqual(results, items) - assert isinstance(results, _collections_abc.dict_items) + data = {"a": 1, "b": 2}.items() + results = pickle_depickle(data) + self.assertEqual(results, data) + assert isinstance(results, type(data)) def test_odict_keys(self): - keys = collections.OrderedDict([("a", 1), ("b", 2)]).keys() - results = pickle_depickle(keys) - self.assertEqual(results, keys) - assert type(keys) == type(results) + data = collections.OrderedDict([("a", 1), ("b", 2)]).keys() + results = pickle_depickle(data) + self.assertEqual(results, data) + assert isinstance(results, type(data)) def test_odict_values(self): - values = collections.OrderedDict([("a", 1), ("b", 2)]).values() - results = pickle_depickle(values) - self.assertEqual(list(results), list(values)) - assert type(values) == type(results) + data = collections.OrderedDict([("a", 1), ("b", 2)]).values() + results = pickle_depickle(data) + self.assertEqual(list(results), list(data)) + assert isinstance(results, type(data)) def test_odict_items(self): - items = collections.OrderedDict([("a", 1), ("b", 2)]).items() - results = pickle_depickle(items) - self.assertEqual(results, items) - assert type(items) == type(results) + data = collections.OrderedDict([("a", 1), ("b", 2)]).items() + results = pickle_depickle(data) + self.assertEqual(results, data) + assert isinstance(results, type(data)) + + def test_view_types(self): + for view_type, info in _VIEWS_TYPES_TABLE.items(): + obj = info.object_type([("a", 1), ("b", 2)]) + at = getattr(obj, info.view, None) + if at is not None: + if callable(at): + data = at() + else: + data = at + results = pickle_depickle(data) + assert isinstance(results, view_type) + self.assertEqual(list(results), list(data)) def test_sliced_and_non_contiguous_memoryview(self): buffer_obj = memoryview(b"Hello!" * 3)[2:15:2] @@ -760,6 +771,7 @@ def test_module_importability(self): # their parent modules are considered importable by cloudpickle. # See the mod_with_dynamic_submodule documentation for more # details of this use case. + _cloudpickle_testpkg = pytest.importorskip("_cloudpickle_testpkg") # noqa F841 import _cloudpickle_testpkg.mod.dynamic_submodule as m assert _should_pickle_by_reference(m) assert pickle_depickle(m, protocol=self.protocol) is m @@ -2066,7 +2078,8 @@ def test_relative_import_inside_function(self): # Make sure relative imports inside round-tripped functions is not # broken. This was a bug in cloudpickle versions <= 0.5.3 and was # re-introduced in 0.8.0. - f, g = relative_imports_factory() + _cloudpickle_testpkg = pytest.importorskip("_cloudpickle_testpkg") # noqa F841 + f, g = _cloudpickle_testpkg.relative_imports_factory() for func, source in zip([f, g], ["module", "package"]): # Make sure relative imports are initially working assert func() == "hello from a {}!".format(source) @@ -2115,6 +2128,7 @@ def f(a, /, b=1): def test___reduce___returns_string(self): # Non regression test for objects with a __reduce__ method returning a # string, meaning "save by attribute using save_global" + _cloudpickle_testpkg = pytest.importorskip("_cloudpickle_testpkg") # noqa F841 from _cloudpickle_testpkg import some_singleton assert some_singleton.__reduce__() == "some_singleton" depickled_singleton = pickle_depickle( @@ -2187,6 +2201,7 @@ def test_pickle_dynamic_typevar_memoization(self): assert depickled_T1 is depickled_T2 def test_pickle_importable_typevar(self): + _cloudpickle_testpkg = pytest.importorskip("_cloudpickle_testpkg") # noqa F841 from _cloudpickle_testpkg import T T1 = pickle_depickle(T, protocol=self.protocol) assert T1 is T @@ -2512,6 +2527,7 @@ def test_pickle_constructs_from_module_registered_for_pickling_by_value(self): def test_pickle_constructs_from_installed_packages_registered_for_pickling_by_value( # noqa self ): + _cloudpickle_testpkg = pytest.importorskip("_cloudpickle_testpkg") # noqa F841 for package_or_module in ["package", "module"]: if package_or_module == "package": import _cloudpickle_testpkg as m @@ -2546,7 +2562,7 @@ def test_pickle_various_versions_of_the_same_function_with_different_pickling_me # pickled in a different way - by value and/or by reference) can # peacefully co-exist (e.g. without globals interaction) in a remote # worker. - import _cloudpickle_testpkg + _cloudpickle_testpkg = pytest.importorskip("_cloudpickle_testpkg") # noqa F841 from _cloudpickle_testpkg import package_function_with_global as f _original_global = _cloudpickle_testpkg.global_variable @@ -2623,7 +2639,7 @@ def test_lookup_module_and_qualname_dynamic_typevar(): def test_lookup_module_and_qualname_importable_typevar(): - import _cloudpickle_testpkg + _cloudpickle_testpkg = pytest.importorskip("_cloudpickle_testpkg") # noqa F841 T = _cloudpickle_testpkg.T module_and_name = _lookup_module_and_qualname(T, name=T.__name__) assert module_and_name is not None @@ -2642,6 +2658,7 @@ def test_lookup_module_and_qualname_stdlib_typevar(): def test_register_pickle_by_value(): + _cloudpickle_testpkg = pytest.importorskip("_cloudpickle_testpkg") # noqa F841 import _cloudpickle_testpkg as pkg import _cloudpickle_testpkg.mod as mod