Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 23 additions & 15 deletions cloudpickle/cloudpickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@
import warnings

from .compat import pickle
from collections import OrderedDict
from typing import Generic, Union, Tuple, Callable
from pickle import _getattribute
from importlib._bootstrap import _find_spec
Expand Down Expand Up @@ -952,22 +951,31 @@ def _get_bases(typ):
return getattr(typ, bases_attr)


def _make_dict_keys(obj, is_ordered=False):
if is_ordered:
return OrderedDict.fromkeys(obj).keys()
def _make_keys_view(obj, typ):
t = typ or dict
if hasattr(t, "fromkeys"):
o = t.fromkeys(obj)
else:
return dict.fromkeys(obj).keys()
o = dict.fromkeys(obj)
try:
o = t(o)
except Exception:
pass
return o.keys()


def _make_dict_values(obj, is_ordered=False):
if is_ordered:
return OrderedDict((i, _) for i, _ in enumerate(obj)).values()
else:
return {i: _ for i, _ in enumerate(obj)}.values()
def _make_values_view(obj, typ):
t = typ or dict
o = t({i: _ for i, _ in enumerate(obj)})
return o.values()


def _make_dict_items(obj, is_ordered=False):
if is_ordered:
return OrderedDict(obj).items()
else:
return obj.items()
def _make_items_view(obj, typ):
t = typ or dict
try:
o = t(obj)
except Exception:
o = dict(obj)
if hasattr(o, "items"):
return o.items()
return o
87 changes: 53 additions & 34 deletions cloudpickle/cloudpickle_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
guards present in cloudpickle.py that were written to handle PyPy specificities
are not present in cloudpickle_fast.py
"""
import _collections_abc
import abc
import copyreg
import io
Expand All @@ -23,7 +22,8 @@
import typing

from enum import Enum
from collections import ChainMap, OrderedDict
from collections import ChainMap, OrderedDict, UserDict, namedtuple
from typing import Iterable

from .compat import pickle, Pickler
from .cloudpickle import (
Expand All @@ -35,7 +35,7 @@
_is_parametrized_type_hint, PYPY, cell_set,
parametrized_type_hint_getinitargs, _create_parametrized_type_hint,
builtin_code_type,
_make_dict_keys, _make_dict_values, _make_dict_items,
_make_keys_view, _make_values_view, _make_items_view,
)


Expand Down Expand Up @@ -418,41 +418,66 @@ def _class_reduce(obj):
return _dynamic_class_reduce(obj)
return NotImplemented

# DICT VIEWS TYPES

def _dict_keys_reduce(obj):
# Safer not to ship the full dict as sending the rest might
# be unintended and could potentially cause leaking of
# sensitive information
return _make_dict_keys, (list(obj), )
ViewInfo = namedtuple("ViewInfo", ["view", "packer", "maker", "object_type"]) # noqa


def _dict_values_reduce(obj):
# Safer not to ship the full dict as sending the rest might
# be unintended and could potentially cause leaking of
# sensitive information
return _make_dict_values, (list(obj), )
_VIEW_ATTRS_INFO = [
ViewInfo("keys", list, _make_keys_view, None),
ViewInfo("values", list, _make_values_view, None),
ViewInfo("items", dict, _make_items_view, None),
]


def _dict_items_reduce(obj):
return _make_dict_items, (dict(obj), )
_VIEWS_TYPES_TABLE = {}


def _odict_keys_reduce(obj):
# Safer not to ship the full dict as sending the rest might
# be unintended and could potentially cause leaking of
# sensitive information
return _make_dict_keys, (list(obj), True)
def register_views_types(types, attr_info=None, overwrite=False):
"""Register views types, returns a copy."""
if not isinstance(types, Iterable):
types = [types]
if attr_info is None:
attr_info = _VIEW_ATTRS_INFO.copy()
elif isinstance(attr_info, ViewInfo):
attr_info = [attr_info]
for typ in types:
for info in attr_info:
if isinstance(typ, type):
object_type = typ
obj_instance = object_type()
else:
object_type = type(typ)
obj_instance = typ
a = getattr(obj_instance, info.view, None)
if callable(a):
view_obj = a()
else:
view_obj = a
if a is None:
continue
view_type = type(view_obj)
if view_type not in _VIEWS_TYPES_TABLE or overwrite:
_VIEWS_TYPES_TABLE[view_type] = ViewInfo(
info.view,
info.packer,
info.maker,
object_type,
)
return _VIEWS_TYPES_TABLE


def _odict_values_reduce(obj):
# Safer not to ship the full dict as sending the rest might
# be unintended and could potentially cause leaking of
# sensitive information
return _make_dict_values, (list(obj), True)
_VIEWABLE_TYPES = [dict, OrderedDict, UserDict]
register_views_types(_VIEWABLE_TYPES)


def _odict_items_reduce(obj):
return _make_dict_items, (dict(obj), True)
def _views_reducer(obj):
typ = type(obj)
info = _VIEWS_TYPES_TABLE[typ]
return info.maker, (info.packer(obj), info.object_type)


_VIEWS_DISPATCH_TABLE = {k: _views_reducer for k in _VIEWS_TYPES_TABLE.keys()}


# COLLECTIONS OF OBJECTS STATE SETTERS
Expand Down Expand Up @@ -528,13 +553,7 @@ class CloudPickler(Pickler):
_dispatch_table[types.MappingProxyType] = _mappingproxy_reduce
_dispatch_table[weakref.WeakSet] = _weakset_reduce
_dispatch_table[typing.TypeVar] = _typevar_reduce
_dispatch_table[_collections_abc.dict_keys] = _dict_keys_reduce
_dispatch_table[_collections_abc.dict_values] = _dict_values_reduce
_dispatch_table[_collections_abc.dict_items] = _dict_items_reduce
_dispatch_table[type(OrderedDict().keys())] = _odict_keys_reduce
_dispatch_table[type(OrderedDict().values())] = _odict_values_reduce
_dispatch_table[type(OrderedDict().items())] = _odict_items_reduce

_dispatch_table.update(_VIEWS_DISPATCH_TABLE)

dispatch_table = ChainMap(_dispatch_table, copyreg.dispatch_table)

Expand Down
77 changes: 47 additions & 30 deletions tests/cloudpickle_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import division

import _collections_abc
import abc
import collections
import base64
Expand Down Expand Up @@ -53,14 +52,13 @@
from cloudpickle.cloudpickle import _make_empty_cell, cell_set
from cloudpickle.cloudpickle import _extract_class_dict, _whichmodule
from cloudpickle.cloudpickle import _lookup_module_and_qualname
from cloudpickle.cloudpickle_fast import _VIEWS_TYPES_TABLE

from .testutils import subprocess_pickle_echo
from .testutils import subprocess_pickle_string
from .testutils import assert_run_python_script
from .testutils import subprocess_worker

from _cloudpickle_testpkg import relative_imports_factory


_TEST_GLOBAL_VARIABLE = "default_value"
_TEST_GLOBAL_VARIABLE2 = "another_value"
Expand Down Expand Up @@ -224,40 +222,53 @@ def test_memoryview(self):
buffer_obj.tobytes())

def test_dict_keys(self):
keys = {"a": 1, "b": 2}.keys()
results = pickle_depickle(keys)
self.assertEqual(results, keys)
assert isinstance(results, _collections_abc.dict_keys)
data = {"a": 1, "b": 2}.keys()
results = pickle_depickle(data)
self.assertEqual(results, data)
assert isinstance(results, type(data))

def test_dict_values(self):
values = {"a": 1, "b": 2}.values()
results = pickle_depickle(values)
self.assertEqual(sorted(results), sorted(values))
assert isinstance(results, _collections_abc.dict_values)
data = {"a": 1, "b": 2}.values()
results = pickle_depickle(data)
self.assertEqual(sorted(results), sorted(data))
assert isinstance(results, type(data))

def test_dict_items(self):
items = {"a": 1, "b": 2}.items()
results = pickle_depickle(items)
self.assertEqual(results, items)
assert isinstance(results, _collections_abc.dict_items)
data = {"a": 1, "b": 2}.items()
results = pickle_depickle(data)
self.assertEqual(results, data)
assert isinstance(results, type(data))

def test_odict_keys(self):
keys = collections.OrderedDict([("a", 1), ("b", 2)]).keys()
results = pickle_depickle(keys)
self.assertEqual(results, keys)
assert type(keys) == type(results)
data = collections.OrderedDict([("a", 1), ("b", 2)]).keys()
results = pickle_depickle(data)
self.assertEqual(results, data)
assert isinstance(results, type(data))

def test_odict_values(self):
values = collections.OrderedDict([("a", 1), ("b", 2)]).values()
results = pickle_depickle(values)
self.assertEqual(list(results), list(values))
assert type(values) == type(results)
data = collections.OrderedDict([("a", 1), ("b", 2)]).values()
results = pickle_depickle(data)
self.assertEqual(list(results), list(data))
assert isinstance(results, type(data))

def test_odict_items(self):
items = collections.OrderedDict([("a", 1), ("b", 2)]).items()
results = pickle_depickle(items)
self.assertEqual(results, items)
assert type(items) == type(results)
data = collections.OrderedDict([("a", 1), ("b", 2)]).items()
results = pickle_depickle(data)
self.assertEqual(results, data)
assert isinstance(results, type(data))

def test_view_types(self):
for view_type, info in _VIEWS_TYPES_TABLE.items():
obj = info.object_type([("a", 1), ("b", 2)])
at = getattr(obj, info.view, None)
if at is not None:
if callable(at):
data = at()
else:
data = at
results = pickle_depickle(data)
assert isinstance(results, view_type)
self.assertEqual(list(results), list(data))

def test_sliced_and_non_contiguous_memoryview(self):
buffer_obj = memoryview(b"Hello!" * 3)[2:15:2]
Expand Down Expand Up @@ -760,6 +771,7 @@ def test_module_importability(self):
# their parent modules are considered importable by cloudpickle.
# See the mod_with_dynamic_submodule documentation for more
# details of this use case.
_cloudpickle_testpkg = pytest.importorskip("_cloudpickle_testpkg") # noqa F841
import _cloudpickle_testpkg.mod.dynamic_submodule as m
assert _should_pickle_by_reference(m)
assert pickle_depickle(m, protocol=self.protocol) is m
Expand Down Expand Up @@ -2066,7 +2078,8 @@ def test_relative_import_inside_function(self):
# Make sure relative imports inside round-tripped functions is not
# broken. This was a bug in cloudpickle versions <= 0.5.3 and was
# re-introduced in 0.8.0.
f, g = relative_imports_factory()
_cloudpickle_testpkg = pytest.importorskip("_cloudpickle_testpkg") # noqa F841
f, g = _cloudpickle_testpkg.relative_imports_factory()
for func, source in zip([f, g], ["module", "package"]):
# Make sure relative imports are initially working
assert func() == "hello from a {}!".format(source)
Expand Down Expand Up @@ -2115,6 +2128,7 @@ def f(a, /, b=1):
def test___reduce___returns_string(self):
# Non regression test for objects with a __reduce__ method returning a
# string, meaning "save by attribute using save_global"
_cloudpickle_testpkg = pytest.importorskip("_cloudpickle_testpkg") # noqa F841
from _cloudpickle_testpkg import some_singleton
assert some_singleton.__reduce__() == "some_singleton"
depickled_singleton = pickle_depickle(
Expand Down Expand Up @@ -2187,6 +2201,7 @@ def test_pickle_dynamic_typevar_memoization(self):
assert depickled_T1 is depickled_T2

def test_pickle_importable_typevar(self):
_cloudpickle_testpkg = pytest.importorskip("_cloudpickle_testpkg") # noqa F841
from _cloudpickle_testpkg import T
T1 = pickle_depickle(T, protocol=self.protocol)
assert T1 is T
Expand Down Expand Up @@ -2512,6 +2527,7 @@ def test_pickle_constructs_from_module_registered_for_pickling_by_value(self):
def test_pickle_constructs_from_installed_packages_registered_for_pickling_by_value( # noqa
self
):
_cloudpickle_testpkg = pytest.importorskip("_cloudpickle_testpkg") # noqa F841
for package_or_module in ["package", "module"]:
if package_or_module == "package":
import _cloudpickle_testpkg as m
Expand Down Expand Up @@ -2546,7 +2562,7 @@ def test_pickle_various_versions_of_the_same_function_with_different_pickling_me
# pickled in a different way - by value and/or by reference) can
# peacefully co-exist (e.g. without globals interaction) in a remote
# worker.
import _cloudpickle_testpkg
_cloudpickle_testpkg = pytest.importorskip("_cloudpickle_testpkg") # noqa F841
from _cloudpickle_testpkg import package_function_with_global as f
_original_global = _cloudpickle_testpkg.global_variable

Expand Down Expand Up @@ -2623,7 +2639,7 @@ def test_lookup_module_and_qualname_dynamic_typevar():


def test_lookup_module_and_qualname_importable_typevar():
import _cloudpickle_testpkg
_cloudpickle_testpkg = pytest.importorskip("_cloudpickle_testpkg") # noqa F841
T = _cloudpickle_testpkg.T
module_and_name = _lookup_module_and_qualname(T, name=T.__name__)
assert module_and_name is not None
Expand All @@ -2642,6 +2658,7 @@ def test_lookup_module_and_qualname_stdlib_typevar():


def test_register_pickle_by_value():
_cloudpickle_testpkg = pytest.importorskip("_cloudpickle_testpkg") # noqa F841
import _cloudpickle_testpkg as pkg
import _cloudpickle_testpkg.mod as mod

Expand Down