Skip to content

Commit 6eef56c

Browse files
author
Claude
committed
Allow runner to override default pickler.
1 parent 5509e5b commit 6eef56c

File tree

4 files changed

+34
-4
lines changed

4 files changed

+34
-4
lines changed

sdks/python/apache_beam/pipeline.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -188,9 +188,6 @@ def __init__(
188188

189189
FileSystems.set_options(self._options)
190190

191-
pickle_library = self._options.view_as(SetupOptions).pickle_library
192-
pickler.set_library(pickle_library)
193-
194191
if runner is None:
195192
runner = self._options.view_as(StandardOptions).runner
196193
if runner is None:
@@ -225,6 +222,16 @@ def __init__(
225222

226223
# Default runner to be used.
227224
self.runner = runner
225+
226+
if (self._options.view_as(SetupOptions).pickle_library == 'default' and
227+
self.runner.default_pickle_library_override() is not None):
228+
logging.info(
229+
"Default pickling library set to : %s.",
230+
runner.default_pickle_library_override())
231+
self._options.view_as(
232+
SetupOptions).pickle_library = runner.default_pickle_library_override(
233+
)
234+
pickler.set_library(self._options.view_as(SetupOptions).pickle_library)
228235
# Stack of transforms generated by nested apply() calls. The stack will
229236
# contain a root node as an enclosing (parent) node for top transforms.
230237
self.transforms_stack = [

sdks/python/apache_beam/pipeline_test.py

Lines changed: 16 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -62,6 +62,7 @@
6262
from apache_beam.transforms.window import TimestampedValue
6363
from apache_beam.utils import windowed_value
6464
from apache_beam.utils.timestamp import MIN_TIMESTAMP
65+
from apache_beam.runners.direct import direct_runner
6566

6667

6768
class FakeUnboundedSource(SourceBase):
@@ -156,6 +157,20 @@ def test_create(self):
156157
pcoll3 = pcoll2 | 'do' >> FlatMap(lambda x: [x + 10])
157158
assert_that(pcoll3, equal_to([14, 15, 16]), label='pcoll3')
158159

160+
@mock.patch('logging.info')
161+
def test_runner_overrides_default_pickler(self, mock_info):
162+
with mock.patch.object(direct_runner.SwitchingDirectRunner,
163+
'default_pickle_library_override') as mock_fn:
164+
mock_fn.return_value = 'dill'
165+
with TestPipeline() as pipeline:
166+
pcoll = pipeline | 'label1' >> Create([1, 2, 3])
167+
assert_that(pcoll, equal_to([1, 2, 3]))
168+
169+
from apache_beam.internal import pickler
170+
from apache_beam.internal import dill_pickler
171+
self.assertIs(pickler.desired_pickle_lib, dill_pickler)
172+
mock_info.assert_any_call('Default pickling library set to : %s.', 'dill')
173+
159174
def test_flatmap_builtin(self):
160175
with TestPipeline() as pipeline:
161176
pcoll = pipeline | 'label1' >> Create([1, 2, 3])
@@ -279,7 +294,7 @@ def test_no_wait_until_finish(self, mock_info):
279294
with Pipeline(runner='DirectRunner',
280295
options=PipelineOptions(["--no_wait_until_finish"])) as p:
281296
_ = p | beam.Create(['test'])
282-
mock_info.assert_called_once_with(
297+
mock_info.assert_any_call(
283298
'Job execution continues without waiting for completion. '
284299
'Use "wait_until_finish" in PipelineResult to block until finished.')
285300
p.result.wait_until_finish()

sdks/python/apache_beam/runners/direct/direct_runner.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -66,6 +66,10 @@ class SwitchingDirectRunner(PipelineRunner):
6666
which supports streaming execution and certain primitives not yet
6767
implemented in the FnApiRunner.
6868
"""
69+
def default_pickle_library_override(self):
70+
"""Default pickle library, can be overridden by runner implementation."""
71+
return 'cloudpickle'
72+
6973
def is_fnapi_compatible(self):
7074
return BundleBasedDirectRunner.is_fnapi_compatible()
7175

sdks/python/apache_beam/runners/runner.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -225,6 +225,10 @@ def check_requirements(
225225
beam_runner_api_pb2.TimeDomain.PROCESSING_TIME):
226226
raise NotImplementedError(timer.time_domain)
227227

228+
def default_pickle_library_override(self):
229+
"""Default pickle library, can be overridden by runner implementation."""
230+
return None
231+
228232

229233
# FIXME: replace with PipelineState(str, enum.Enum)
230234
class PipelineState(object):

0 commit comments

Comments
 (0)