diff --git a/sdks/python/apache_beam/coders/coder_impl.py b/sdks/python/apache_beam/coders/coder_impl.py
index 49cbbdd17e69..afd13baf852d 100644
--- a/sdks/python/apache_beam/coders/coder_impl.py
+++ b/sdks/python/apache_beam/coders/coder_impl.py
@@ -32,6 +32,7 @@
 import decimal
 import enum
+import functools
 import itertools
 import json
 import logging
@@ -50,7 +51,6 @@
 from typing import Tuple
 from typing import Type
 
-import dill
 import numpy as np
 from fastavro import parse_schema
 from fastavro import schemaless_reader
@@ -58,6 +58,7 @@
 
 from apache_beam.coders import observable
 from apache_beam.coders.avro_record import AvroRecord
+from apache_beam.internal import pickler
 from apache_beam.typehints.schemas import named_tuple_from_schema
 from apache_beam.utils import proto_utils
 from apache_beam.utils import windowed_value
@@ -526,7 +527,7 @@ def _deterministic_encoding_error_msg(self, value):
         (value, type(value), self.requires_deterministic_step_label))
 
   def encode_type(self, t, stream):
-    stream.write(dill.dumps(t), True)
+    stream.write(pickler.dumps(t), True)
 
   def decode_type(self, stream):
     return _unpickle_type(stream.read_all(True))
@@ -589,16 +590,23 @@ def decode_from_stream(self, stream, nested):
 
 
 _unpickled_types = {}  # type: Dict[bytes, type]
 
 
+def _unpickle_named_tuple_reducer(self, bs):
+  return (_unpickle_named_tuple, (bs, tuple(self)))
+
+
 def _unpickle_type(bs):
   t = _unpickled_types.get(bs, None)
   if t is None:
-    t = _unpickled_types[bs] = dill.loads(bs)
+    t = _unpickled_types[bs] = pickler.loads(bs)
     # Fix unpicklable anonymous named tuples for Python 3.6.
     if t.__base__ is tuple and hasattr(t, '_fields'):
       try:
         pickle.loads(pickle.dumps(t))
       except pickle.PicklingError:
+        # partialmethod (unlike partial) binds the instance on attribute
+        # lookup, and the reducer stays a picklable module-level function.
-        t.__reduce__ = lambda self: (_unpickle_named_tuple, (bs, tuple(self)))
+        t.__reduce__ = functools.partialmethod(
+            _unpickle_named_tuple_reducer, bs)
   return t
 
 
@@ -838,6 +843,7 @@ def decode_from_stream(self, in_, nested):
     if IntervalWindow is None:
       from apache_beam.transforms.window import IntervalWindow
     # instantiating with None is not part of the public interface
+    # pylint: disable=too-many-function-args
     typed_value = IntervalWindow(None, None)  # type: ignore[arg-type]
     typed_value._end_micros = (
         1000 * self._to_normal_time(in_.read_bigendian_uint64()))
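A note on the `_unpickle_type` change above: the class-level `__reduce__` override exists so that *instances* of anonymous named tuples (whose classes stock pickle cannot save by reference) reconstruct through a module-level function. The standalone toy below (the `_types` registry, `_rebuild`, and `_reducer` are illustrative stand-ins, not Beam code) sketches why `functools.partialmethod` fits here: unlike `functools.partial` it binds the instance on attribute lookup, and unlike a lambda the reducer remains a plain, picklable module-level function.

```python
# Standalone toy, not Beam code: pickling instances of an "anonymous"
# named tuple class by rerouting __reduce__ through a module-level function.
import collections
import functools
import pickle

_types = {}  # stand-in for Beam's _unpickled_types cache


def _rebuild(key, values):
  # Module-level, so pickle can save it by reference.
  return _types[key](*values)


def _reducer(self, key):
  # partialmethod passes the instance first, then the pre-bound args.
  return (_rebuild, (key, tuple(self)))


T = collections.namedtuple('T', ['a', 'b'])
T.__qualname__ = '<anonymous>'  # simulate a class pickle cannot look up
_types['k'] = T

try:
  pickle.dumps(T(1, 2))  # fails: the class cannot be saved by reference
except pickle.PicklingError:
  T.__reduce__ = functools.partialmethod(_reducer, 'k')

assert pickle.loads(pickle.dumps(T(1, 2))) == (1, 2)
```

A plain `functools.partial` stored on the class would be returned unbound by attribute lookup, so `instance.__reduce__()` would be called without the instance; `partialmethod` implements the descriptor protocol and supplies it.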
diff --git a/sdks/python/apache_beam/coders/coders.py b/sdks/python/apache_beam/coders/coders.py
index cb23e3967e33..26dad3341c87 100644
--- a/sdks/python/apache_beam/coders/coders.py
+++ b/sdks/python/apache_beam/coders/coders.py
@@ -59,6 +59,8 @@
 
 from apache_beam.coders import coder_impl
 from apache_beam.coders.avro_record import AvroRecord
+from apache_beam.internal import cloudpickle_pickler
+from apache_beam.internal import pickler
 from apache_beam.portability import common_urns
 from apache_beam.portability import python_urns
 from apache_beam.portability.api import beam_runner_api_pb2
@@ -77,21 +79,17 @@
 # pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports
 
 # pylint: disable=wrong-import-order, wrong-import-position
-# Avoid dependencies on the full SDK.
 try:
-  # Import dill from the pickler module to make sure our monkey-patching of dill
-  # occurs.
-  from apache_beam.internal.dill_pickler import dill
+  from apache_beam.internal import dill_pickler
 except ImportError:
-  # We fall back to using the stock dill library in tests that don't use the
-  # full Python SDK.
-  import dill
+  dill_pickler = None  # type: ignore
 
 __all__ = [
     'Coder',
     'AvroGenericCoder',
     'BooleanCoder',
     'BytesCoder',
+    'CloudpickleCoder',
     'DillCoder',
     'FastPrimitivesCoder',
     'FloatCoder',
@@ -123,14 +121,12 @@
 
 
 def serialize_coder(coder):
-  from apache_beam.internal import pickler
   return b'%s$%s' % (
       coder.__class__.__name__.encode('utf-8'),
       pickler.dumps(coder, use_zlib=True))
 
 
 def deserialize_coder(serialized):
-  from apache_beam.internal import pickler
   return pickler.loads(serialized.split(b'$', 1)[1], use_zlib=True)
 
 
@@ -812,7 +808,7 @@ def maybe_dill_dumps(o):
   try:
     return pickle.dumps(o, pickle.HIGHEST_PROTOCOL)
   except Exception:  # pylint: disable=broad-except
-    return dill.dumps(o)
+    return dill_pickler.dumps(o)
 
 
 def maybe_dill_loads(o):
@@ -820,7 +816,7 @@ def maybe_dill_loads(o):
   try:
     return pickle.loads(o)
   except Exception:  # pylint: disable=broad-except
-    return dill.loads(o)
+    return dill_pickler.loads(o)
 
 
 class _PickleCoderBase(FastCoder):
@@ -860,7 +856,6 @@ def __init__(self, cache_size=16):
     self.cache_size = cache_size
 
   def _create_impl(self):
-    from apache_beam.internal import pickler
     dumps = pickler.dumps
     mdumps = lru_cache(maxsize=self.cache_size, typed=True)(dumps)
 
@@ -896,11 +891,24 @@ def to_type_hint(self):
 
 
 class DillCoder(_PickleCoderBase):
   """Coder using dill's pickle functionality."""
+  def __init__(self):
+    if dill_pickler is None:
+      raise ImportError(
+          "Dill is not installed. To use DillCoder, please install "
+          "apache-beam[dill] or install dill separately.")
+
   def _create_impl(self):
     return coder_impl.CallbackCoderImpl(maybe_dill_dumps, maybe_dill_loads)
 
 
+class CloudpickleCoder(_PickleCoderBase):
+  """Coder using Apache Beam's vendored Cloudpickle pickler."""
+  def _create_impl(self):
+    return coder_impl.CallbackCoderImpl(
+        cloudpickle_pickler.dumps, cloudpickle_pickler.loads)
+
+
 class DeterministicFastPrimitivesCoder(FastCoder):
   """Throws runtime errors when encoding non-deterministic values."""
   def __init__(self, coder, step_label):
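The coders.py changes make dill optional and add `CloudpickleCoder` alongside `DillCoder`. A quick usage sketch of the round-trip behavior the new tests exercise (assuming a Beam build with this patch applied; the closure value is arbitrary):

```python
# Sketch, assuming this patch is applied: CloudpickleCoder round-trips
# closure cells that stock pickle rejects, via Beam's vendored cloudpickle.
from apache_beam.coders import coders

coder = coders.CloudpickleCoder()
cell = (lambda x: lambda: x)(42).__closure__[0]  # a CellType instance
assert coder.decode(coder.encode(cell)).cell_contents == 42

# DillCoder now fails fast at construction time when dill is not installed:
try:
  coders.DillCoder()
except ImportError:
  pass  # only raised in environments without the apache-beam[dill] extra
```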
diff --git a/sdks/python/apache_beam/coders/coders_test_common.py b/sdks/python/apache_beam/coders/coders_test_common.py
index bed93cbc5545..8c0c483ac231 100644
--- a/sdks/python/apache_beam/coders/coders_test_common.py
+++ b/sdks/python/apache_beam/coders/coders_test_common.py
@@ -30,6 +30,8 @@
 from typing import NamedTuple
 
 import pytest
+from parameterized import param
+from parameterized import parameterized
 
 from apache_beam.coders import proto2_coder_test_messages_pb2 as test_message
 from apache_beam.coders import coders
@@ -219,7 +221,12 @@ def test_memoizing_pickle_coder(self):
     coder = coders._MemoizingPickleCoder()
     self.check_coder(coder, *self.test_values)
 
-  def test_deterministic_coder(self):
+  @parameterized.expand([
+      param(pickle_lib='dill'),
+      param(pickle_lib='cloudpickle'),
+  ])
+  def test_deterministic_coder(self, pickle_lib):
+    pickler.set_library(pickle_lib)
     coder = coders.FastPrimitivesCoder()
     deterministic_coder = coders.DeterministicFastPrimitivesCoder(coder, 'step')
     self.check_coder(deterministic_coder, *self.test_values_deterministic)
@@ -282,6 +289,13 @@ def test_dill_coder(self):
         coders.TupleCoder((coders.VarIntCoder(), coders.DillCoder())),
         (1, cell_value))
 
+  def test_cloudpickle_pickle_coder(self):
+    cell_value = (lambda x: lambda: x)(0).__closure__[0]
+    self.check_coder(coders.CloudpickleCoder(), 'a', 1, cell_value)
+    self.check_coder(
+        coders.TupleCoder((coders.VarIntCoder(), coders.CloudpickleCoder())),
+        (1, cell_value))
+
   def test_fast_primitives_coder(self):
     coder = coders.FastPrimitivesCoder(coders.SingletonCoder(len))
     self.check_coder(coder, *self.test_values)
@@ -539,6 +553,7 @@ def test_windowed_value_coder(self):
   def test_param_windowed_value_coder(self):
     from apache_beam.transforms.window import IntervalWindow
     from apache_beam.utils.windowed_value import PaneInfo
+    # pylint: disable=too-many-function-args
     wv = windowed_value.create(
         b'',
         # Milliseconds to microseconds
diff --git a/sdks/python/apache_beam/internal/pickler.py b/sdks/python/apache_beam/internal/pickler.py
index 256f88c5453f..9cdb2581e739 100644
--- a/sdks/python/apache_beam/internal/pickler.py
+++ b/sdks/python/apache_beam/internal/pickler.py
@@ -29,7 +29,12 @@
 """
 
 from apache_beam.internal import cloudpickle_pickler
-from apache_beam.internal import dill_pickler
+
+try:
+  from apache_beam.internal import dill_pickler
+  DILL_AVAILABLE = True
+except ImportError:
+  DILL_AVAILABLE = False
 
 USE_CLOUDPICKLE = 'cloudpickle'
 USE_DILL = 'dill'
@@ -43,7 +48,6 @@ def dumps(
     enable_trace=True,
     use_zlib=False,
     enable_best_effort_determinism=False) -> bytes:
-
   return desired_pickle_lib.dumps(
       o,
       enable_trace=enable_trace,
@@ -73,9 +77,15 @@ def load_session(file_path):
 
 def set_library(selected_library=DEFAULT_PICKLE_LIB):
   """ Sets pickle library that will be used. """
+  if selected_library == USE_DILL and not DILL_AVAILABLE:
+    raise ImportError(
+        "Dill is not installed. To use the dill pickler, please install "
+        "apache-beam[dill] or install dill separately.")
+
   global desired_pickle_lib
   # If switching to or from dill, update the pickler hook overrides.
-  if (selected_library == USE_DILL) != (desired_pickle_lib == dill_pickler):
+  if DILL_AVAILABLE and ((selected_library == USE_DILL) !=
+                         (desired_pickle_lib == dill_pickler)):
     dill_pickler.override_pickler_hooks(selected_library == USE_DILL)
 
   if selected_library == 'default':
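With `dill_pickler` now imported lazily, `pickler.set_library` fails fast when dill is requested but absent, instead of crashing later on a missing module. A small sketch of the intended calling pattern (constants per the patch; the fallback choice is illustrative):

```python
# Sketch of runtime pickler selection under this patch. USE_CLOUDPICKLE and
# USE_DILL are the module constants shown above; the fallback is illustrative.
from apache_beam.internal import pickler

pickler.set_library(pickler.USE_CLOUDPICKLE)  # vendored, always importable

try:
  pickler.set_library(pickler.USE_DILL)
except ImportError:
  # dill is an optional extra now; stay on cloudpickle instead.
  pickler.set_library(pickler.USE_CLOUDPICKLE)
```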
""" - THRESHOLD_STATE_INDEX = ReadModifyWriteStateSpec('saved_tracker', DillCoder()) + THRESHOLD_STATE_INDEX = ReadModifyWriteStateSpec( + 'saved_tracker', CloudpickleCoder()) def __init__(self, threshold_fn_spec: Spec): assert isinstance(threshold_fn_spec.config, dict) diff --git a/sdks/python/setup.py b/sdks/python/setup.py index 2b21d0463c98..0f6f5b380800 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -341,12 +341,6 @@ def get_portability_package_data(): install_requires=[ 'crcmod>=1.7,<2.0', 'orjson>=3.9.7,<4', - # Dill doesn't have forwards-compatibility guarantees within minor - # version. Pickles created with a new version of dill may not unpickle - # using older version of dill. It is best to use the same version of - # dill on client and server, therefore list of allowed versions is - # very narrow. See: https://github.com/uqfoundation/dill/issues/341. - 'dill>=0.3.1.1,<0.3.2', 'fastavro>=0.23.6,<2', 'fasteners>=0.3,<1.0', # TODO(https://github.com/grpc/grpc/issues/37710): Unpin grpc @@ -390,6 +384,15 @@ def get_portability_package_data(): python_requires=python_requires, # BEAM-8840: Do NOT use tests_require or setup_requires. extras_require={ + 'dill' : [ + # Dill doesn't have forwards-compatibility guarantees within minor + # version. Pickles created with a new version of dill may not + # unpickle using older version of dill. It is best to use the same + # version of dill on client and server, therefore list of allowed + # versions is very narrow. + # See: https://github.com/uqfoundation/dill/issues/341. + 'dill>=0.3.1.1,<0.3.2', + ], 'docs': [ 'jinja2>=3.0,<3.2', 'Sphinx>=7.0.0,<8.0', @@ -401,6 +404,7 @@ def get_portability_package_data(): 'virtualenv-clone>=0.5,<1.0', ], 'test': [ + 'dill>=0.3.1.1,<0.3.2', 'docstring-parser>=0.15,<1.0', 'freezegun>=0.3.12', 'jinja2>=3.0,<3.2',