Skip to content

Commit 8e70cba

Browse files
authored
Pass options in DaskOptions inheritance hierarchy only for Dask runner (#37101)
* Pass options in DaskOptions inheritance hierarchy only for Dask runner
* Address comments
1 parent ccad976 commit 8e70cba

File tree

3 files changed

+28
-8
lines changed

3 files changed

+28
-8
lines changed

sdks/python/apache_beam/options/pipeline_options.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -486,11 +486,12 @@ def get_all_options(
486486
drop_default=False,
487487
add_extra_args_fn: Optional[Callable[[_BeamArgumentParser], None]] = None,
488488
retain_unknown_options=False,
489-
display_warnings=False) -> Dict[str, Any]:
489+
display_warnings=False,
490+
current_only=False,
491+
) -> Dict[str, Any]:
490492
"""Returns a dictionary of all defined arguments.
491493
492-
Returns a dictionary of all defined arguments (arguments that are defined in
493-
any subclass of PipelineOptions) into a dictionary.
494+
Collects all defined arguments into a single dictionary.
494495
495496
Args:
496497
drop_default: If set to true, options that are equal to their default
@@ -500,6 +501,9 @@ def get_all_options(
500501
retain_unknown_options: If set to true, options not recognized by any
501502
known pipeline options class will still be included in the result. If
502503
set to false, they will be discarded.
504+
current_only: If set to true, only returns options defined in this class.
505+
Otherwise, arguments that are defined in any subclass of PipelineOptions
506+
are returned (default).
503507
504508
Returns:
505509
Dictionary of all args and values.
@@ -510,8 +514,11 @@ def get_all_options(
510514
# instance of each subclass to avoid conflicts.
511515
subset = {}
512516
parser = _BeamArgumentParser(allow_abbrev=False)
513-
for cls in PipelineOptions.__subclasses__():
514-
subset.setdefault(str(cls), cls)
517+
if current_only:
518+
subset.setdefault(str(type(self)), type(self))
519+
else:
520+
for cls in PipelineOptions.__subclasses__():
521+
subset.setdefault(str(cls), cls)
515522
for cls in subset.values():
516523
cls._add_argparse_args(parser) # pylint: disable=protected-access
517524
if add_extra_args_fn:
@@ -562,7 +569,7 @@ def add_new_arg(arg, **kwargs):
562569
continue
563570
parsed_args, _ = parser.parse_known_args(self._flags)
564571
else:
565-
if unknown_args:
572+
if unknown_args and not current_only:
566573
_LOGGER.warning("Discarding unparseable args: %s", unknown_args)
567574
parsed_args = known_args
568575
result = vars(parsed_args)
@@ -580,7 +587,7 @@ def add_new_arg(arg, **kwargs):
580587
if overrides:
581588
if retain_unknown_options:
582589
result.update(overrides)
583-
else:
590+
elif not current_only:
584591
_LOGGER.warning("Discarding invalid overrides: %s", overrides)
585592

586593
return result

sdks/python/apache_beam/options/pipeline_options_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,19 @@ def test_get_all_options(self, flags, expected, _):
238238
options.view_as(PipelineOptionsTest.MockOptions).mock_multi_option,
239239
expected['mock_multi_option'])
240240

241+
def test_get_superclass_options(self):
242+
flags = ["--mock_option", "mock", "--fake_option", "fake"]
243+
options = PipelineOptions(flags=flags).view_as(
244+
PipelineOptionsTest.FakeOptions)
245+
items = options.get_all_options(current_only=True).items()
246+
print(items)
247+
self.assertTrue(('fake_option', 'fake') in items)
248+
self.assertFalse(('mock_option', 'mock') in items)
249+
items = options.view_as(PipelineOptionsTest.MockOptions).get_all_options(
250+
current_only=True).items()
251+
self.assertFalse(('fake_option', 'fake') in items)
252+
self.assertTrue(('mock_option', 'mock') in items)
253+
241254
@parameterized.expand(TEST_CASES)
242255
def test_subclasses_of_pipeline_options_can_be_instantiated(
243256
self, flags, expected, _):

sdks/python/apache_beam/runners/dask/dask_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def run_pipeline(self, pipeline, options):
236236
'DaskRunner is not available. Please install apache_beam[dask].')
237237

238238
dask_options = options.view_as(DaskOptions).get_all_options(
239-
drop_default=True)
239+
drop_default=True, current_only=True)
240240
bag_kwargs = DaskOptions._extract_bag_kwargs(dask_options)
241241
client = ddist.Client(**dask_options)
242242

0 commit comments

Comments
 (0)