diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index 47aaeff43a6f..6260975b32c9 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -28,6 +28,7 @@ import traceback import types import typing +from itertools import dropwhile from apache_beam import coders from apache_beam import pvalue @@ -1387,6 +1388,59 @@ def partition_for(self, element, num_partitions, *args, **kwargs): return self._fn(element, num_partitions, *args, **kwargs) +def _get_function_body_without_inners(func): + source_lines = inspect.getsourcelines(func)[0] + source_lines = dropwhile(lambda x: x.startswith("@"), source_lines) + def_line = next(source_lines).strip() + if def_line.startswith("def ") and def_line.endswith(":"): + first_line = next(source_lines) + indentation = len(first_line) - len(first_line.lstrip()) + final_lines = [first_line[indentation:]] + + skip_inner_def = False + if first_line[indentation:].startswith("def "): + skip_inner_def = True + for line in source_lines: + line_indentation = len(line) - len(line.lstrip()) + + if line[indentation:].startswith("def "): + skip_inner_def = True + continue + + if skip_inner_def and line_indentation == indentation: + skip_inner_def = False + + if skip_inner_def and line_indentation > indentation: + continue + final_lines.append(line[indentation:]) + + return "".join(final_lines) + else: + return def_line.rsplit(":")[-1].strip() + + +def _check_fn_use_yield_and_return(fn): + if isinstance(fn, types.BuiltinFunctionType): + return False + try: + source_code = _get_function_body_without_inners(fn) + has_yield = False + has_return = False + for line in source_code.split("\n"): + if line.lstrip().startswith("yield ") or line.lstrip().startswith( + "yield("): + has_yield = True + if line.lstrip().startswith("return ") or line.lstrip().startswith( + "return("): + has_return = True + if has_yield and has_return: + return True + return False + except Exception as e: + _LOGGER.debug(str(e)) + return False + + class ParDo(PTransformWithSideInputs): """A :class:`ParDo` transform. @@ -1427,6 +1481,14 @@ def __init__(self, fn, *args, **kwargs): if not isinstance(self.fn, DoFn): raise TypeError('ParDo must be called with a DoFn instance.') + # DoFn.process cannot allow both return and yield + if _check_fn_use_yield_and_return(self.fn.process): + _LOGGER.warning( + 'Using yield and return in the process method ' + 'of %s can lead to unexpected behavior, see:' + 'https://github.com/apache/beam/issues/22969.', + self.fn.__class__) + # Validate the DoFn by creating a DoFnSignature from apache_beam.runners.common import DoFnSignature self._signature = DoFnSignature(self.fn) @@ -2663,6 +2725,7 @@ def from_runner_api_parameter(unused_ptransform, combine_payload, context): class CombineValuesDoFn(DoFn): """DoFn for performing per-key Combine transforms.""" + def __init__( self, input_pcoll_type, @@ -2725,6 +2788,7 @@ def default_type_hints(self): class _CombinePerKeyWithHotKeyFanout(PTransform): + def __init__( self, combine_fn, # type: CombineFn @@ -2939,11 +3003,12 @@ class GroupBy(PTransform): The GroupBy operation can be made into an aggregating operation by invoking its `aggregate_field` method. """ + def __init__( self, *fields, # type: typing.Union[str, typing.Callable] **kwargs # type: typing.Union[str, typing.Callable] - ): + ): if len(fields) == 1 and not kwargs: self._force_tuple_keys = False name = fields[0] if isinstance(fields[0], str) else 'key' @@ -2966,7 +3031,7 @@ def aggregate_field( field, # type: typing.Union[str, typing.Callable] combine_fn, # type: typing.Union[typing.Callable, CombineFn] dest, # type: str - ): + ): """Returns a grouping operation that also aggregates grouped values. Args: @@ -3054,7 +3119,7 @@ def aggregate_field( field, # type: typing.Union[str, typing.Callable] combine_fn, # type: typing.Union[typing.Callable, CombineFn] dest, # type: str - ): + ): field = _expr_to_callable(field, 0) return _GroupAndAggregate( self._grouping, list(self._aggregations) + [(field, combine_fn, dest)]) @@ -3096,10 +3161,12 @@ class Select(PTransform): pcoll | beam.Map(lambda x: beam.Row(a=x.a, b=foo(x))) """ - def __init__(self, - *args, # type: typing.Union[str, typing.Callable] - **kwargs # type: typing.Union[str, typing.Callable] - ): + + def __init__( + self, + *args, # type: typing.Union[str, typing.Callable] + **kwargs # type: typing.Union[str, typing.Callable] + ): self._fields = [( expr if isinstance(expr, str) else 'arg%02d' % ix, _expr_to_callable(expr, ix)) for (ix, expr) in enumerate(args) diff --git a/sdks/python/apache_beam/transforms/core_test.py b/sdks/python/apache_beam/transforms/core_test.py new file mode 100644 index 000000000000..0fba28266138 --- /dev/null +++ b/sdks/python/apache_beam/transforms/core_test.py @@ -0,0 +1,113 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Unit tests for the core python file.""" +# pytype: skip-file + +import logging +import unittest + +import pytest + +import apache_beam as beam + + +class TestDoFn1(beam.DoFn): + def process(self, element): + yield element + + +class TestDoFn2(beam.DoFn): + def process(self, element): + def inner_func(x): + yield x + + return inner_func(element) + + +class TestDoFn3(beam.DoFn): + """mixing return and yield is not allowed""" + def process(self, element): + if not element: + return -1 + yield element + + +class TestDoFn4(beam.DoFn): + """test the variable name containing return""" + def process(self, element): + my_return = element + yield my_return + + +class TestDoFn5(beam.DoFn): + """test the variable name containing yield""" + def process(self, element): + my_yield = element + return my_yield + + +class TestDoFn6(beam.DoFn): + """test the variable name containing return""" + def process(self, element): + return_test = element + yield return_test + + +class TestDoFn7(beam.DoFn): + """test the variable name containing yield""" + def process(self, element): + yield_test = element + return yield_test + + +class TestDoFn8(beam.DoFn): + """test the code containing yield and yield from""" + def process(self, element): + if not element: + yield from [1, 2, 3] + else: + yield element + + +class CreateTest(unittest.TestCase): + @pytest.fixture(autouse=True) + def inject_fixtures(self, caplog): + self._caplog = caplog + + def test_dofn_with_yield_and_return(self): + warning_text = 'Using yield and return' + + with self._caplog.at_level(logging.WARNING): + assert beam.ParDo(sum) + assert beam.ParDo(TestDoFn1()) + assert beam.ParDo(TestDoFn2()) + assert beam.ParDo(TestDoFn4()) + assert beam.ParDo(TestDoFn5()) + assert beam.ParDo(TestDoFn6()) + assert beam.ParDo(TestDoFn7()) + assert beam.ParDo(TestDoFn8()) + assert warning_text not in self._caplog.text + + with self._caplog.at_level(logging.WARNING): + beam.ParDo(TestDoFn3()) + assert warning_text in self._caplog.text + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + unittest.main()