14 changes: 10 additions & 4 deletions sdks/python/apache_beam/coders/coder_impl.py
@@ -32,6 +32,7 @@

 import decimal
 import enum
+import functools
 import itertools
 import json
 import logging
@@ -50,14 +51,14 @@
 from typing import Tuple
 from typing import Type

-import dill
 import numpy as np
 from fastavro import parse_schema
 from fastavro import schemaless_reader
 from fastavro import schemaless_writer

 from apache_beam.coders import observable
 from apache_beam.coders.avro_record import AvroRecord
+from apache_beam.internal import pickler
 from apache_beam.typehints.schemas import named_tuple_from_schema
 from apache_beam.utils import proto_utils
 from apache_beam.utils import windowed_value
@@ -526,7 +527,7 @@ def _deterministic_encoding_error_msg(self, value):
         (value, type(value), self.requires_deterministic_step_label))

   def encode_type(self, t, stream):
-    stream.write(dill.dumps(t), True)
+    stream.write(pickler.dumps(t), True)

   def decode_type(self, stream):
     return _unpickle_type(stream.read_all(True))
@@ -589,16 +590,20 @@ def decode_from_stream(self, stream, nested):
 _unpickled_types = {}  # type: Dict[bytes, type]


+def _unpickle_named_tuple_reducer(bs, self):
+  return (_unpickle_named_tuple, (bs, tuple(self)))
+
+
 def _unpickle_type(bs):
   t = _unpickled_types.get(bs, None)
   if t is None:
-    t = _unpickled_types[bs] = dill.loads(bs)
+    t = _unpickled_types[bs] = pickler.loads(bs)
     # Fix unpicklable anonymous named tuples for Python 3.6.
     if t.__base__ is tuple and hasattr(t, '_fields'):
       try:
         pickle.loads(pickle.dumps(t))
       except pickle.PicklingError:
-        t.__reduce__ = lambda self: (_unpickle_named_tuple, (bs, tuple(self)))
+        t.__reduce__ = functools.partial(_unpickle_named_tuple_reducer, bs)
   return t

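The lambda-to-functools.partial swap matters because the patched type's __reduce__ attribute must itself survive pickling once the type is serialized by value: the stock pickler cannot serialize a lambda that closes over bs, while a functools.partial over a module-level function pickles cleanly. A minimal, self-contained sketch with illustrative names (not Beam APIs):

import functools
import pickle


def reducer(payload, obj):
  # Stand-in for _unpickle_named_tuple_reducer: a __reduce__-style
  # (callable, args) pair built from state bound at patch time.
  return (bytes, (payload, ))


payload = b'opaque-pickled-type-bytes'

as_partial = functools.partial(reducer, payload)
restored = pickle.loads(pickle.dumps(as_partial))  # round-trips fine
assert restored(None) == (bytes, (payload, ))

as_lambda = lambda obj: (bytes, (payload, ))
try:
  pickle.dumps(as_lambda)  # the closure over payload is not picklable
except pickle.PicklingError:
  print('lambda reducer cannot be pickled by the stock pickler')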
@@ -838,6 +843,7 @@ def decode_from_stream(self, in_, nested):
     if IntervalWindow is None:
       from apache_beam.transforms.window import IntervalWindow
     # instantiating with None is not part of the public interface
+    # pylint: disable=too-many-function-args
     typed_value = IntervalWindow(None, None)  # type: ignore[arg-type]
     typed_value._end_micros = (
         1000 * self._to_normal_time(in_.read_bigendian_uint64()))
34 changes: 21 additions & 13 deletions sdks/python/apache_beam/coders/coders.py
@@ -59,6 +59,8 @@

 from apache_beam.coders import coder_impl
 from apache_beam.coders.avro_record import AvroRecord
+from apache_beam.internal import cloudpickle_pickler
+from apache_beam.internal import pickler
 from apache_beam.portability import common_urns
 from apache_beam.portability import python_urns
 from apache_beam.portability.api import beam_runner_api_pb2
@@ -77,21 +79,17 @@
 # pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports

 # pylint: disable=wrong-import-order, wrong-import-position
-# Avoid dependencies on the full SDK.
 try:
-  # Import dill from the pickler module to make sure our monkey-patching of dill
-  # occurs.
-  from apache_beam.internal.dill_pickler import dill
+  from apache_beam.internal import dill_pickler
 except ImportError:
-  # We fall back to using the stock dill library in tests that don't use the
-  # full Python SDK.
-  import dill
+  dill_pickler = None  # type: ignore

 __all__ = [
     'Coder',
     'AvroGenericCoder',
     'BooleanCoder',
     'BytesCoder',
+    'CloudpickleCoder',
     'DillCoder',
     'FastPrimitivesCoder',
     'FloatCoder',
@@ -123,14 +121,12 @@


 def serialize_coder(coder):
-  from apache_beam.internal import pickler
   return b'%s$%s' % (
       coder.__class__.__name__.encode('utf-8'),
       pickler.dumps(coder, use_zlib=True))


 def deserialize_coder(serialized):
-  from apache_beam.internal import pickler
   return pickler.loads(serialized.split(b'$', 1)[1], use_zlib=True)

@@ -812,15 +808,15 @@ def maybe_dill_dumps(o):
   try:
     return pickle.dumps(o, pickle.HIGHEST_PROTOCOL)
   except Exception:  # pylint: disable=broad-except
-    return dill.dumps(o)
+    return dill_pickler.dumps(o)


 def maybe_dill_loads(o):
   """Unpickle using cPickle or the Dill pickler as a fallback."""
   try:
     return pickle.loads(o)
   except Exception:  # pylint: disable=broad-except
-    return dill.loads(o)
+    return dill_pickler.loads(o)


 class _PickleCoderBase(FastCoder):
@@ -860,7 +856,6 @@ def __init__(self, cache_size=16):
     self.cache_size = cache_size

   def _create_impl(self):
-    from apache_beam.internal import pickler
     dumps = pickler.dumps

     mdumps = lru_cache(maxsize=self.cache_size, typed=True)(dumps)
@@ -896,11 +891,24 @@ def to_type_hint(self):


 class DillCoder(_PickleCoderBase):
-  """Coder using dill's pickle functionality."""
+  def __init__(self):
+    """Coder using dill's pickle functionality."""
+    if dill_pickler is None:
+      raise ImportError(
+          "Dill is not installed. To use DillCoder, please install "
+          "apache-beam[dill] or install dill separately.")
+
   def _create_impl(self):
     return coder_impl.CallbackCoderImpl(maybe_dill_dumps, maybe_dill_loads)


+class CloudpickleCoder(_PickleCoderBase):
+  """Coder using Apache Beam's vendored Cloudpickle pickler."""
+  def _create_impl(self):
+    return coder_impl.CallbackCoderImpl(
+        cloudpickle_pickler.dumps, cloudpickle_pickler.loads)
+
+
 class DeterministicFastPrimitivesCoder(FastCoder):
   """Throws runtime errors when encoding non-deterministic values."""
   def __init__(self, coder, step_label):
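A quick round-trip sketch of what the new coder buys, based only on the implementation above (CloudpickleCoder defers entirely to the vendored cloudpickle, so callables that defeat the stock pickler should survive):

from apache_beam.coders import CloudpickleCoder

coder = CloudpickleCoder()

# cloudpickle serializes lambdas by value, so this round-trips.
add_one = lambda x: x + 1
assert coder.decode(coder.encode(add_one))(1) == 2

# Plain values work too, mirroring test_cloudpickle_pickle_coder below.
assert coder.decode(coder.encode('a')) == 'a'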
17 changes: 16 additions & 1 deletion sdks/python/apache_beam/coders/coders_test_common.py
@@ -30,6 +30,8 @@
 from typing import NamedTuple

 import pytest
+from parameterized import param
+from parameterized import parameterized

 from apache_beam.coders import proto2_coder_test_messages_pb2 as test_message
 from apache_beam.coders import coders
@@ -219,7 +221,12 @@ def test_memoizing_pickle_coder(self):
     coder = coders._MemoizingPickleCoder()
     self.check_coder(coder, *self.test_values)

-  def test_deterministic_coder(self):
+  @parameterized.expand([
+      param(pickle_lib='dill'),
+      param(pickle_lib='cloudpickle'),
+  ])
+  def test_deterministic_coder(self, pickle_lib):
+    pickler.set_library(pickle_lib)
     coder = coders.FastPrimitivesCoder()
     deterministic_coder = coders.DeterministicFastPrimitivesCoder(coder, 'step')
     self.check_coder(deterministic_coder, *self.test_values_deterministic)
@@ -282,6 +289,13 @@ def test_dill_coder(self):
         coders.TupleCoder((coders.VarIntCoder(), coders.DillCoder())),
         (1, cell_value))

+  def test_cloudpickle_pickle_coder(self):
+    cell_value = (lambda x: lambda: x)(0).__closure__[0]
+    self.check_coder(coders.CloudpickleCoder(), 'a', 1, cell_value)
+    self.check_coder(
+        coders.TupleCoder((coders.VarIntCoder(), coders.CloudpickleCoder())),
+        (1, cell_value))
+
   def test_fast_primitives_coder(self):
     coder = coders.FastPrimitivesCoder(coders.SingletonCoder(len))
     self.check_coder(coder, *self.test_values)
@@ -539,6 +553,7 @@ def test_windowed_value_coder(self):
   def test_param_windowed_value_coder(self):
     from apache_beam.transforms.window import IntervalWindow
     from apache_beam.utils.windowed_value import PaneInfo
+    # pylint: disable=too-many-function-args
     wv = windowed_value.create(
         b'',
         # Milliseconds to microseconds
16 changes: 13 additions & 3 deletions sdks/python/apache_beam/internal/pickler.py
@@ -29,7 +29,12 @@
 """

 from apache_beam.internal import cloudpickle_pickler
-from apache_beam.internal import dill_pickler
+
+try:
+  from apache_beam.internal import dill_pickler
+  DILL_AVAILABLE = True
+except ImportError:
+  DILL_AVAILABLE = False

 USE_CLOUDPICKLE = 'cloudpickle'
 USE_DILL = 'dill'
@@ -43,7 +48,6 @@ def dumps(
     enable_trace=True,
     use_zlib=False,
     enable_best_effort_determinism=False) -> bytes:
-
   return desired_pickle_lib.dumps(
       o,
       enable_trace=enable_trace,
@@ -73,9 +77,15 @@ def load_session(file_path):

 def set_library(selected_library=DEFAULT_PICKLE_LIB):
   """ Sets pickle library that will be used. """
+  if selected_library == USE_DILL and not DILL_AVAILABLE:
+    raise ImportError(
+        "Dill is not installed. To use DillCoder, please install "
+        "apache-beam[dill] or install dill separately.")
+
   global desired_pickle_lib
   # If switching to or from dill, update the pickler hook overrides.
-  if (selected_library == USE_DILL) != (desired_pickle_lib == dill_pickler):
+  if DILL_AVAILABLE and ((selected_library == USE_DILL) !=
+                         (desired_pickle_lib == dill_pickler)):
     dill_pickler.override_pickler_hooks(selected_library == USE_DILL)

   if selected_library == 'default':
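The intended behavior of the reworked set_library, sketched for an environment without dill installed (an assumption; with the extra installed both branches succeed):

from apache_beam.internal import pickler

# cloudpickle is always bundled, so this path never needs dill.
pickler.set_library(pickler.USE_CLOUDPICKLE)
data = pickler.dumps({'key': [1, 2, 3]})
assert pickler.loads(data) == {'key': [1, 2, 3]}

# Selecting dill without the extra installed now fails fast.
try:
  pickler.set_library(pickler.USE_DILL)
except ImportError as e:
  print(e)  # points at the new apache-beam[dill] extra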
8 changes: 5 additions & 3 deletions sdks/python/apache_beam/ml/anomaly/transforms.py
@@ -24,7 +24,7 @@
 from typing import TypeVar

 import apache_beam as beam
-from apache_beam.coders import DillCoder
+from apache_beam.coders import CloudpickleCoder
 from apache_beam.ml.anomaly import aggregations
 from apache_beam.ml.anomaly.base import AggregationFn
 from apache_beam.ml.anomaly.base import AnomalyDetector
@@ -53,7 +53,8 @@ class _ScoreAndLearnDoFn(beam.DoFn):
   then updates the model with the same data. It maintains the model state
   using Beam's state management.
   """
-  MODEL_STATE_INDEX = ReadModifyWriteStateSpec('saved_model', DillCoder())
+  MODEL_STATE_INDEX = ReadModifyWriteStateSpec(
+      'saved_model', CloudpickleCoder())

   def __init__(self, detector_spec: Spec):
     self._detector_spec = detector_spec
@@ -222,7 +223,8 @@ class _StatefulThresholdDoFn(_BaseThresholdDoFn):
     AssertionError: If the provided `threshold_fn_spec` leads to the
       creation of a stateless `ThresholdFn`.
   """
-  THRESHOLD_STATE_INDEX = ReadModifyWriteStateSpec('saved_tracker', DillCoder())
+  THRESHOLD_STATE_INDEX = ReadModifyWriteStateSpec(
+      'saved_tracker', CloudpickleCoder())

   def __init__(self, threshold_fn_spec: Spec):
     assert isinstance(threshold_fn_spec.config, dict)
16 changes: 10 additions & 6 deletions sdks/python/setup.py
@@ -341,12 +341,6 @@ def get_portability_package_data():
     install_requires=[
         'crcmod>=1.7,<2.0',
         'orjson>=3.9.7,<4',
-        # Dill doesn't have forwards-compatibility guarantees within minor
-        # version. Pickles created with a new version of dill may not unpickle
-        # using older version of dill. It is best to use the same version of
-        # dill on client and server, therefore list of allowed versions is
-        # very narrow. See: https://github.com/uqfoundation/dill/issues/341.
-        'dill>=0.3.1.1,<0.3.2',
         'fastavro>=0.23.6,<2',
         'fasteners>=0.3,<1.0',
         # TODO(https://github.com/grpc/grpc/issues/37710): Unpin grpc
@@ -390,6 +384,15 @@ def get_portability_package_data():
     python_requires=python_requires,
     # BEAM-8840: Do NOT use tests_require or setup_requires.
     extras_require={
+        'dill': [
+            # Dill doesn't have forwards-compatibility guarantees within minor
+            # version. Pickles created with a new version of dill may not
+            # unpickle using older version of dill. It is best to use the same
+            # version of dill on client and server, therefore list of allowed
+            # versions is very narrow.
+            # See: https://github.com/uqfoundation/dill/issues/341.
+            'dill>=0.3.1.1,<0.3.2',
+        ],
         'docs': [
             'jinja2>=3.0,<3.2',
             'Sphinx>=7.0.0,<8.0',
@@ -401,6 +404,7 @@ def get_portability_package_data():
             'virtualenv-clone>=0.5,<1.0',
         ],
         'test': [
+            'dill>=0.3.1.1,<0.3.2',
            'docstring-parser>=0.15,<1.0',
            'freezegun>=0.3.12',
            'jinja2>=3.0,<3.2',
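With dill demoted to an extra, anything outside the SDK that still leans on dill-based pickling has to treat it as optional. The probe below mirrors the guard pattern this diff adds to pickler.py and coders.py; the install command is the assumed spelling for the new extras key:

# Assumed install command for the new extra:
#   pip install "apache-beam[dill]"

try:
  import dill  # noqa: F401  -- only present with the [dill] extra
  DILL_AVAILABLE = True
except ImportError:
  DILL_AVAILABLE = False

if not DILL_AVAILABLE:
  print('dill not installed: DillCoder and pickler.USE_DILL are unavailable')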