Skip to content

Commit 544ce26

Browse files
claudevdmClaude
authored andcommitted
Make dill optional and use selected pickle library for coder types.
1 parent 9fd7d09 commit 544ce26

File tree

6 files changed

+76
-30
lines changed

6 files changed

+76
-30
lines changed

sdks/python/apache_beam/coders/coder_impl.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
import decimal
3434
import enum
35+
import functools
3536
import itertools
3637
import json
3738
import logging
@@ -50,14 +51,14 @@
5051
from typing import Tuple
5152
from typing import Type
5253

53-
import dill
5454
import numpy as np
5555
from fastavro import parse_schema
5656
from fastavro import schemaless_reader
5757
from fastavro import schemaless_writer
5858

5959
from apache_beam.coders import observable
6060
from apache_beam.coders.avro_record import AvroRecord
61+
from apache_beam.internal import pickler
6162
from apache_beam.typehints.schemas import named_tuple_from_schema
6263
from apache_beam.utils import proto_utils
6364
from apache_beam.utils import windowed_value
@@ -526,7 +527,7 @@ def _deterministic_encoding_error_msg(self, value):
526527
(value, type(value), self.requires_deterministic_step_label))
527528

528529
def encode_type(self, t, stream):
529-
stream.write(dill.dumps(t), True)
530+
stream.write(pickler.dumps(t), True)
530531

531532
def decode_type(self, stream):
532533
return _unpickle_type(stream.read_all(True))
@@ -589,16 +590,20 @@ def decode_from_stream(self, stream, nested):
589590
_unpickled_types = {} # type: Dict[bytes, type]
590591

591592

593+
def _unpickle_named_tuple_reducer(bs, self):
594+
return (_unpickle_named_tuple, (bs, tuple(self)))
595+
596+
592597
def _unpickle_type(bs):
593598
t = _unpickled_types.get(bs, None)
594599
if t is None:
595-
t = _unpickled_types[bs] = dill.loads(bs)
600+
t = _unpickled_types[bs] = pickler.loads(bs)
596601
# Fix unpicklable anonymous named tuples for Python 3.6.
597602
if t.__base__ is tuple and hasattr(t, '_fields'):
598603
try:
599604
pickle.loads(pickle.dumps(t))
600605
except pickle.PicklingError:
601-
t.__reduce__ = lambda self: (_unpickle_named_tuple, (bs, tuple(self)))
606+
t.__reduce__ = functools.partial(_unpickle_named_tuple_reducer, bs)
602607
return t
603608

604609

@@ -838,6 +843,7 @@ def decode_from_stream(self, in_, nested):
838843
if IntervalWindow is None:
839844
from apache_beam.transforms.window import IntervalWindow
840845
# instantiating with None is not part of the public interface
846+
# pylint: disable=too-many-function-args
841847
typed_value = IntervalWindow(None, None) # type: ignore[arg-type]
842848
typed_value._end_micros = (
843849
1000 * self._to_normal_time(in_.read_bigendian_uint64()))

sdks/python/apache_beam/coders/coders.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@
5959

6060
from apache_beam.coders import coder_impl
6161
from apache_beam.coders.avro_record import AvroRecord
62+
from apache_beam.internal import cloudpickle_pickler
63+
from apache_beam.internal import pickler
6264
from apache_beam.portability import common_urns
6365
from apache_beam.portability import python_urns
6466
from apache_beam.portability.api import beam_runner_api_pb2
@@ -77,21 +79,17 @@
7779
# pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports
7880

7981
# pylint: disable=wrong-import-order, wrong-import-position
80-
# Avoid dependencies on the full SDK.
8182
try:
82-
# Import dill from the pickler module to make sure our monkey-patching of dill
83-
# occurs.
84-
from apache_beam.internal.dill_pickler import dill
83+
from apache_beam.internal import dill_pickler
8584
except ImportError:
86-
# We fall back to using the stock dill library in tests that don't use the
87-
# full Python SDK.
88-
import dill
85+
dill_pickler = None # type: ignore
8986

9087
__all__ = [
9188
'Coder',
9289
'AvroGenericCoder',
9390
'BooleanCoder',
9491
'BytesCoder',
92+
'CloudpickleCoder',
9593
'DillCoder',
9694
'FastPrimitivesCoder',
9795
'FloatCoder',
@@ -123,14 +121,12 @@
123121

124122

125123
def serialize_coder(coder):
126-
from apache_beam.internal import pickler
127124
return b'%s$%s' % (
128125
coder.__class__.__name__.encode('utf-8'),
129126
pickler.dumps(coder, use_zlib=True))
130127

131128

132129
def deserialize_coder(serialized):
133-
from apache_beam.internal import pickler
134130
return pickler.loads(serialized.split(b'$', 1)[1], use_zlib=True)
135131

136132

@@ -812,15 +808,15 @@ def maybe_dill_dumps(o):
812808
try:
813809
return pickle.dumps(o, pickle.HIGHEST_PROTOCOL)
814810
except Exception: # pylint: disable=broad-except
815-
return dill.dumps(o)
811+
return dill_pickler.dumps(o)
816812

817813

818814
def maybe_dill_loads(o):
819815
"""Unpickle using cPickle or the Dill pickler as a fallback."""
820816
try:
821817
return pickle.loads(o)
822818
except Exception: # pylint: disable=broad-except
823-
return dill.loads(o)
819+
return dill_pickler.loads(o)
824820

825821

826822
class _PickleCoderBase(FastCoder):
@@ -860,7 +856,6 @@ def __init__(self, cache_size=16):
860856
self.cache_size = cache_size
861857

862858
def _create_impl(self):
863-
from apache_beam.internal import pickler
864859
dumps = pickler.dumps
865860

866861
mdumps = lru_cache(maxsize=self.cache_size, typed=True)(dumps)
@@ -896,11 +891,24 @@ def to_type_hint(self):
896891

897892

898893
class DillCoder(_PickleCoderBase):
899-
"""Coder using dill's pickle functionality."""
894+
def __init__(self):
895+
"""Coder using dill's pickle functionality."""
896+
if dill_pickler is None:
897+
raise ImportError(
898+
"Dill is not installed. To use DillCoder, please install "
899+
"apache-beam[dill] or install dill separately.")
900+
900901
def _create_impl(self):
901902
return coder_impl.CallbackCoderImpl(maybe_dill_dumps, maybe_dill_loads)
902903

903904

905+
class CloudpickleCoder(_PickleCoderBase):
906+
"""Coder using Apache Beam's vendored Cloudpickle pickler."""
907+
def _create_impl(self):
908+
return coder_impl.CallbackCoderImpl(
909+
cloudpickle_pickler.dumps, cloudpickle_pickler.loads)
910+
911+
904912
class DeterministicFastPrimitivesCoder(FastCoder):
905913
"""Throws runtime errors when encoding non-deterministic values."""
906914
def __init__(self, coder, step_label):

sdks/python/apache_beam/coders/coders_test_common.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
from typing import NamedTuple
3131

3232
import pytest
33+
from parameterized import param
34+
from parameterized import parameterized
3335

3436
from apache_beam.coders import proto2_coder_test_messages_pb2 as test_message
3537
from apache_beam.coders import coders
@@ -219,7 +221,12 @@ def test_memoizing_pickle_coder(self):
219221
coder = coders._MemoizingPickleCoder()
220222
self.check_coder(coder, *self.test_values)
221223

222-
def test_deterministic_coder(self):
224+
@parameterized.expand([
225+
param(pickle_lib='dill'),
226+
param(pickle_lib='cloudpickle'),
227+
])
228+
def test_deterministic_coder(self, pickle_lib):
229+
pickler.set_library(pickle_lib)
223230
coder = coders.FastPrimitivesCoder()
224231
deterministic_coder = coders.DeterministicFastPrimitivesCoder(coder, 'step')
225232
self.check_coder(deterministic_coder, *self.test_values_deterministic)
@@ -282,6 +289,13 @@ def test_dill_coder(self):
282289
coders.TupleCoder((coders.VarIntCoder(), coders.DillCoder())),
283290
(1, cell_value))
284291

292+
def test_cloudpickle_pickle_coder(self):
293+
cell_value = (lambda x: lambda: x)(0).__closure__[0]
294+
self.check_coder(coders.CloudpickleCoder(), 'a', 1, cell_value)
295+
self.check_coder(
296+
coders.TupleCoder((coders.VarIntCoder(), coders.CloudpickleCoder())),
297+
(1, cell_value))
298+
285299
def test_fast_primitives_coder(self):
286300
coder = coders.FastPrimitivesCoder(coders.SingletonCoder(len))
287301
self.check_coder(coder, *self.test_values)
@@ -539,6 +553,7 @@ def test_windowed_value_coder(self):
539553
def test_param_windowed_value_coder(self):
540554
from apache_beam.transforms.window import IntervalWindow
541555
from apache_beam.utils.windowed_value import PaneInfo
556+
# pylint: disable=too-many-function-args
542557
wv = windowed_value.create(
543558
b'',
544559
# Milliseconds to microseconds

sdks/python/apache_beam/internal/pickler.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,12 @@
2929
"""
3030

3131
from apache_beam.internal import cloudpickle_pickler
32-
from apache_beam.internal import dill_pickler
32+
33+
try:
34+
from apache_beam.internal import dill_pickler
35+
DILL_AVAILABLE = True
36+
except ImportError:
37+
DILL_AVAILABLE = False
3338

3439
USE_CLOUDPICKLE = 'cloudpickle'
3540
USE_DILL = 'dill'
@@ -43,7 +48,6 @@ def dumps(
4348
enable_trace=True,
4449
use_zlib=False,
4550
enable_best_effort_determinism=False) -> bytes:
46-
4751
return desired_pickle_lib.dumps(
4852
o,
4953
enable_trace=enable_trace,
@@ -73,9 +77,16 @@ def load_session(file_path):
7377

7478
def set_library(selected_library=DEFAULT_PICKLE_LIB):
7579
""" Sets pickle library that will be used. """
80+
if selected_library == USE_DILL and not DILL_AVAILABLE:
81+
if not DILL_AVAILABLE:
82+
raise ImportError(
83+
"Dill is not installed. To use DillCoder, please install "
84+
"apache-beam[dill] or install dill separately.")
85+
7686
global desired_pickle_lib
7787
# If switching to or from dill, update the pickler hook overrides.
78-
if (selected_library == USE_DILL) != (desired_pickle_lib == dill_pickler):
88+
if DILL_AVAILABLE and ((selected_library == USE_DILL) !=
89+
(desired_pickle_lib == dill_pickler)):
7990
dill_pickler.override_pickler_hooks(selected_library == USE_DILL)
8091

8192
if selected_library == 'default':

sdks/python/apache_beam/ml/anomaly/transforms.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from typing import TypeVar
2525

2626
import apache_beam as beam
27-
from apache_beam.coders import DillCoder
27+
from apache_beam.coders import InternalPickleCoder
2828
from apache_beam.ml.anomaly import aggregations
2929
from apache_beam.ml.anomaly.base import AggregationFn
3030
from apache_beam.ml.anomaly.base import AnomalyDetector
@@ -53,7 +53,8 @@ class _ScoreAndLearnDoFn(beam.DoFn):
5353
then updates the model with the same data. It maintains the model state
5454
using Beam's state management.
5555
"""
56-
MODEL_STATE_INDEX = ReadModifyWriteStateSpec('saved_model', DillCoder())
56+
MODEL_STATE_INDEX = ReadModifyWriteStateSpec(
57+
'saved_model', InternalPickleCoder())
5758

5859
def __init__(self, detector_spec: Spec):
5960
self._detector_spec = detector_spec
@@ -222,7 +223,8 @@ class _StatefulThresholdDoFn(_BaseThresholdDoFn):
222223
AssertionError: If the provided `threshold_fn_spec` leads to the
223224
creation of a stateless `ThresholdFn`.
224225
"""
225-
THRESHOLD_STATE_INDEX = ReadModifyWriteStateSpec('saved_tracker', DillCoder())
226+
THRESHOLD_STATE_INDEX = ReadModifyWriteStateSpec(
227+
'saved_tracker', InternalPickleCoder())
226228

227229
def __init__(self, threshold_fn_spec: Spec):
228230
assert isinstance(threshold_fn_spec.config, dict)

sdks/python/setup.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -341,12 +341,6 @@ def get_portability_package_data():
341341
install_requires=[
342342
'crcmod>=1.7,<2.0',
343343
'orjson>=3.9.7,<4',
344-
# Dill doesn't have forwards-compatibility guarantees within minor
345-
# version. Pickles created with a new version of dill may not unpickle
346-
# using older version of dill. It is best to use the same version of
347-
# dill on client and server, therefore list of allowed versions is
348-
# very narrow. See: https://github.com/uqfoundation/dill/issues/341.
349-
'dill>=0.3.1.1,<0.3.2',
350344
'fastavro>=0.23.6,<2',
351345
'fasteners>=0.3,<1.0',
352346
# TODO(https://github.com/grpc/grpc/issues/37710): Unpin grpc
@@ -390,6 +384,15 @@ def get_portability_package_data():
390384
python_requires=python_requires,
391385
# BEAM-8840: Do NOT use tests_require or setup_requires.
392386
extras_require={
387+
'dill' : [
388+
# Dill doesn't have forwards-compatibility guarantees within minor
389+
# version. Pickles created with a new version of dill may not
390+
# unpickle using older version of dill. It is best to use the same
391+
# version of dill on client and server, therefore list of allowed
392+
# versions is very narrow.
393+
# See: https://github.com/uqfoundation/dill/issues/341.
394+
'dill>=0.3.1.1,<0.3.2',
395+
],
393396
'docs': [
394397
'jinja2>=3.0,<3.2',
395398
'Sphinx>=7.0.0,<8.0',
@@ -401,6 +404,7 @@ def get_portability_package_data():
401404
'virtualenv-clone>=0.5,<1.0',
402405
],
403406
'test': [
407+
'dill>=0.3.1.1,<0.3.2',
404408
'docstring-parser>=0.15,<1.0',
405409
'freezegun>=0.3.12',
406410
'jinja2>=3.0,<3.2',

0 commit comments

Comments
 (0)