Skip to content

Commit bfa0c59

Browse files
authored
The Bag Partition is now configurable. (#33805)
* The Bag Partition is now configurable. Configuring the number of partitions in the Dask runner is very important to tune performance. This CL gives users control over this parameter. * Apply formatter. * Passing lint via the `run_pylint.sh` script. * Implementing review feedback. * Attempting to pass lint/fmt check. * Fixing isort issues by reading CI output. * More indentation. * rm blank line for isort.
1 parent b1d5e00 commit bfa0c59

File tree

4 files changed

+82
-7
lines changed

4 files changed

+82
-7
lines changed

CHANGES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@
8686
* Support the Process Environment for execution in Prism ([#33651](https://github.com/apache/beam/pull/33651))
8787
* Support the AnyOf Environment for execution in Prism ([#33705](https://github.com/apache/beam/pull/33705))
8888
* This improves support for developing Xlang pipelines, when using a compatible cross language service.
89+
* Partitions are now configurable for the DaskRunner in the Python SDK ([#33805](https://github.com/apache/beam/pull/33805)).
8990

9091
## Breaking Changes
9192

sdks/python/apache_beam/runners/dask/dask_runner.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,18 @@ def _parse_timeout(candidate):
5858
import dask
5959
return dask.config.no_default
6060

61+
@staticmethod
62+
def _extract_bag_kwargs(dask_options: t.Dict) -> t.Dict:
63+
"""Parse keyword arguments for `dask.Bag`s; used in graph translation."""
64+
out = {}
65+
66+
if npartitions := dask_options.pop('npartitions', None):
67+
out['npartitions'] = npartitions
68+
if partition_size := dask_options.pop('partition_size', None):
69+
out['partition_size'] = partition_size
70+
71+
return out
72+
6173
@classmethod
6274
def _add_argparse_args(cls, parser: argparse.ArgumentParser) -> None:
6375
parser.add_argument(
@@ -93,6 +105,21 @@ def _add_argparse_args(cls, parser: argparse.ArgumentParser) -> None:
93105
default=512,
94106
help='The number of open comms to maintain at once in the connection '
95107
'pool.')
108+
partitions_parser = parser.add_mutually_exclusive_group()
109+
partitions_parser.add_argument(
110+
'--dask_npartitions',
111+
dest='npartitions',
112+
type=int,
113+
default=None,
114+
help='The desired number of `dask.Bag` partitions. When unspecified, '
115+
'an educated guess is made.')
116+
partitions_parser.add_argument(
117+
'--dask_partition_size',
118+
dest='partition_size',
119+
type=int,
120+
default=None,
121+
help='The length of each `dask.Bag` partition. When unspecified, '
122+
'an educated guess is made.')
96123

97124

98125
@dataclasses.dataclass
@@ -139,17 +166,20 @@ def metrics(self):
139166
class DaskRunner(BundleBasedDirectRunner):
140167
"""Executes a pipeline on a Dask distributed client."""
141168
@staticmethod
142-
def to_dask_bag_visitor() -> PipelineVisitor:
169+
def to_dask_bag_visitor(bag_kwargs=None) -> PipelineVisitor:
143170
from dask import bag as db
144171

172+
if bag_kwargs is None:
173+
bag_kwargs = {}
174+
145175
@dataclasses.dataclass
146176
class DaskBagVisitor(PipelineVisitor):
147177
bags: t.Dict[AppliedPTransform, db.Bag] = dataclasses.field(
148178
default_factory=collections.OrderedDict)
149179

150180
def visit_transform(self, transform_node: AppliedPTransform) -> None:
151181
op_class = TRANSLATIONS.get(transform_node.transform.__class__, NoOp)
152-
op = op_class(transform_node)
182+
op = op_class(transform_node, bag_kwargs=bag_kwargs)
153183

154184
op_kws = {"input_bag": None, "side_inputs": None}
155185
inputs = list(transform_node.inputs)
@@ -195,7 +225,7 @@ def is_fnapi_compatible():
195225
def run_pipeline(self, pipeline, options):
196226
import dask
197227

198-
# TODO(alxr): Create interactive notebook support.
228+
# TODO(alxmrs): Create interactive notebook support.
199229
if is_in_notebook():
200230
raise NotImplementedError('interactive support will come later!')
201231

@@ -207,11 +237,12 @@ def run_pipeline(self, pipeline, options):
207237

208238
dask_options = options.view_as(DaskOptions).get_all_options(
209239
drop_default=True)
240+
bag_kwargs = DaskOptions._extract_bag_kwargs(dask_options)
210241
client = ddist.Client(**dask_options)
211242

212243
pipeline.replace_all(dask_overrides())
213244

214-
dask_visitor = self.to_dask_bag_visitor()
245+
dask_visitor = self.to_dask_bag_visitor(bag_kwargs)
215246
pipeline.visit(dask_visitor)
216247
# The dictionary in this visitor keeps a mapping of every Beam
217248
# PTransform to the equivalent Bag operation. This is highly

sdks/python/apache_beam/runners/dask/dask_runner_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,25 @@ def test_parser_destinations__agree_with_dask_client(self):
6666
with self.subTest(f'{opt_name} in dask.distributed.Client constructor'):
6767
self.assertIn(opt_name, client_args)
6868

69+
def test_parser_extract_bag_kwargs__deletes_dask_kwargs(self):
70+
options = PipelineOptions('--dask_npartitions 8'.split())
71+
dask_options = options.view_as(DaskOptions).get_all_options()
72+
73+
self.assertIn('npartitions', dask_options)
74+
bag_kwargs = DaskOptions._extract_bag_kwargs(dask_options)
75+
self.assertNotIn('npartitions', dask_options)
76+
self.assertEqual(bag_kwargs, {'npartitions': 8})
77+
78+
def test_parser_extract_bag_kwargs__unconfigured(self):
79+
options = PipelineOptions()
80+
dask_options = options.view_as(DaskOptions).get_all_options()
81+
82+
# It's present as a default option.
83+
self.assertIn('npartitions', dask_options)
84+
bag_kwargs = DaskOptions._extract_bag_kwargs(dask_options)
85+
self.assertNotIn('npartitions', dask_options)
86+
self.assertEqual(bag_kwargs, {})
87+
6988

7089
class DaskRunnerRunPipelineTest(unittest.TestCase):
7190
"""Test class used to introspect the dask runner via a debugger."""

sdks/python/apache_beam/runners/dask/transform_evaluator.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
"""
2323
import abc
2424
import dataclasses
25+
import logging
2526
import math
2627
import typing as t
2728
from dataclasses import field
@@ -52,6 +53,8 @@
5253
# Value types for PCollections (possibly Windowed Values).
5354
PCollVal = t.Union[WindowedValue, t.Any]
5455

56+
_LOGGER = logging.getLogger(__name__)
57+
5558

5659
def get_windowed_value(item: t.Any, window_fn: WindowFn) -> WindowedValue:
5760
"""Wraps a value (item) inside a Window."""
@@ -127,8 +130,11 @@ class DaskBagOp(abc.ABC):
127130
Attributes
128131
applied: The underlying `AppliedPTransform` which holds the code for the
129132
target operation.
133+
bag_kwargs: (optional) Keyword arguments applied to input bags, usually
134+
from the pipeline's `DaskOptions`.
130135
"""
131136
applied: AppliedPTransform
137+
bag_kwargs: t.Dict = dataclasses.field(default_factory=dict)
132138

133139
@property
134140
def transform(self):
@@ -151,10 +157,28 @@ def apply(self, input_bag: OpInput, side_inputs: OpSide = None) -> db.Bag:
151157
assert input_bag is None, 'Create expects no input!'
152158
original_transform = t.cast(_Create, self.transform)
153159
items = original_transform.values
160+
161+
npartitions = self.bag_kwargs.get('npartitions')
162+
partition_size = self.bag_kwargs.get('partition_size')
163+
if npartitions and partition_size:
164+
raise ValueError(
165+
f'Please specify either `dask_npartitions` or '
166+
f'`dask_partition_size` but not both: '
167+
f'{npartitions=}, {partition_size=}.')
168+
if not npartitions and not partition_size:
169+
# partition_size is inversely related to `npartitions`.
170+
# Ideal "chunk sizes" in dask are around 10-100 MBs.
171+
# Let's hope ~128 items per partition is around this
172+
# memory overhead.
173+
default_size = 128
174+
partition_size = max(default_size, math.ceil(math.sqrt(len(items)) / 10))
175+
if partition_size == default_size:
176+
_LOGGER.warning(
177+
'The new default partition size is %d, it used to be 1 '
178+
'in previous DaskRunner versions.' % default_size)
179+
154180
return db.from_sequence(
155-
items,
156-
partition_size=max(
157-
1, math.ceil(math.sqrt(len(items)) / math.sqrt(100))))
181+
items, npartitions=npartitions, partition_size=partition_size)
158182

159183

160184
def apply_dofn_to_bundle(

0 commit comments

Comments
 (0)