Skip to content

Commit 057f297

Browse files
committed
Pass options in DaskOptions inheritance hierarchy only for Dask runner
1 parent 31a96b6 commit 057f297

File tree

3 files changed

+44
-8
lines changed

3 files changed

+44
-8
lines changed

sdks/python/apache_beam/options/pipeline_options.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -486,11 +486,12 @@ def get_all_options(
486486
drop_default=False,
487487
add_extra_args_fn: Optional[Callable[[_BeamArgumentParser], None]] = None,
488488
retain_unknown_options=False,
489-
display_warnings=False) -> Dict[str, Any]:
489+
display_warnings=False,
490+
hierarchy_only=False,
491+
) -> Dict[str, Any]:
490492
"""Returns a dictionary of all defined arguments.
491493
492-
Returns a dictionary of all defined arguments (arguments that are defined in
493-
any subclass of PipelineOptions) into a dictionary.
494+
Returns all defined arguments as a dictionary.
494495
495496
Args:
496497
drop_default: If set to true, options that are equal to their default
@@ -500,6 +501,9 @@ def get_all_options(
500501
retain_unknown_options: If set to true, options not recognized by any
501502
known pipeline options class will still be included in the result. If
502503
set to false, they will be discarded.
504+
hierarchy_only: If set to true, returns only the options defined in this
505+
class and its superclasses. Otherwise, arguments that are defined in
506+
any subclass of PipelineOptions are returned (default).
503507
504508
Returns:
505509
Dictionary of all args and values.
@@ -510,8 +514,13 @@ def get_all_options(
510514
# instance of each subclass to avoid conflicts.
511515
subset = {}
512516
parser = _BeamArgumentParser(allow_abbrev=False)
513-
for cls in PipelineOptions.__subclasses__():
514-
subset.setdefault(str(cls), cls)
517+
if not hierarchy_only:
518+
for cls in PipelineOptions.__subclasses__():
519+
subset.setdefault(str(cls), cls)
520+
else:
521+
for cls in self.__class__.__mro__:
522+
if issubclass(cls, PipelineOptions):
523+
subset.setdefault(str(cls), cls)
515524
for cls in subset.values():
516525
cls._add_argparse_args(parser) # pylint: disable=protected-access
517526
if add_extra_args_fn:
@@ -562,7 +571,7 @@ def add_new_arg(arg, **kwargs):
562571
continue
563572
parsed_args, _ = parser.parse_known_args(self._flags)
564573
else:
565-
if unknown_args:
574+
if unknown_args and not hierarchy_only:
566575
_LOGGER.warning("Discarding unparseable args: %s", unknown_args)
567576
parsed_args = known_args
568577
result = vars(parsed_args)
@@ -580,7 +589,7 @@ def add_new_arg(arg, **kwargs):
580589
if overrides:
581590
if retain_unknown_options:
582591
result.update(overrides)
583-
else:
592+
elif not hierarchy_only:
584593
_LOGGER.warning("Discarding invalid overrides: %s", overrides)
585594

586595
return result

sdks/python/apache_beam/options/pipeline_options_test.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,11 @@ def _add_argparse_args(cls, parser):
204204
parser.add_argument(
205205
'--fake_multi_option', action='append', help='fake multi option')
206206

207+
class FakeSubclassOptions(FakeOptions):
208+
@classmethod
209+
def _add_argparse_args(cls, parser):
210+
parser.add_argument('--fake_sub_option', help='fake option')
211+
207212
@parameterized.expand(TEST_CASES)
208213
def test_display_data(self, flags, _, display_data):
209214
options = PipelineOptions(flags=flags)
@@ -238,6 +243,28 @@ def test_get_all_options(self, flags, expected, _):
238243
options.view_as(PipelineOptionsTest.MockOptions).mock_multi_option,
239244
expected['mock_multi_option'])
240245

246+
def test_get_superclass_options(self):
247+
flags = [
248+
"--mock_option",
249+
"mock",
250+
"--fake_option",
251+
"fake",
252+
"--fake_sub_option",
253+
"fake_sub"
254+
]
255+
options = PipelineOptions(flags=flags).view_as(
256+
PipelineOptionsTest.FakeSubclassOptions)
257+
items = options.get_all_options(hierarchy_only=True).items()
258+
print(items)
259+
self.assertTrue(('fake_option', 'fake') in items)
260+
self.assertTrue(('fake_sub_option', 'fake_sub') in items)
261+
self.assertFalse(('mock_option', 'mock') in items)
262+
items = options.view_as(PipelineOptionsTest.MockOptions).get_all_options(
263+
hierarchy_only=True).items()
264+
self.assertFalse(('fake_option', 'fake') in items)
265+
self.assertFalse(('fake_sub_option', 'fake_sub') in items)
266+
self.assertTrue(('mock_option', 'mock') in items)
267+
241268
@parameterized.expand(TEST_CASES)
242269
def test_subclasses_of_pipeline_options_can_be_instantiated(
243270
self, flags, expected, _):

sdks/python/apache_beam/runners/dask/dask_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def run_pipeline(self, pipeline, options):
236236
'DaskRunner is not available. Please install apache_beam[dask].')
237237

238238
dask_options = options.view_as(DaskOptions).get_all_options(
239-
drop_default=True)
239+
drop_default=True, hierarchy_only=True)
240240
bag_kwargs = DaskOptions._extract_bag_kwargs(dask_options)
241241
client = ddist.Client(**dask_options)
242242

0 commit comments

Comments
 (0)