Skip to content
Merged
98 changes: 83 additions & 15 deletions sdks/python/apache_beam/transforms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import traceback
import types
import typing
from itertools import dropwhile

from apache_beam import coders
from apache_beam import pvalue
Expand Down Expand Up @@ -1387,6 +1388,56 @@ def partition_for(self, element, num_partitions, *args, **kwargs):
return self._fn(element, num_partitions, *args, **kwargs)


def _get_function_body_without_inners(func):
source_lines = inspect.getsourcelines(func)[0]
source_lines = dropwhile(lambda x: x.startswith("@"), source_lines)
def_line = next(source_lines).strip()
if def_line.startswith("def ") and def_line.endswith(":"):
first_line = next(source_lines)
indentation = len(first_line) - len(first_line.lstrip())
final_lines = [first_line[indentation:]]

skip_inner_def = False
if first_line[indentation:].startswith("def "):
skip_inner_def = True
for line in source_lines:
line_indentation = len(line) - len(line.lstrip())

if line[indentation:].startswith("def "):
skip_inner_def = True
continue

if skip_inner_def and line_indentation == indentation:
skip_inner_def = False

if skip_inner_def and line_indentation > indentation:
continue
final_lines.append(line[indentation:])

return "".join(final_lines)
else:
return def_line.rsplit(":")[-1].strip()


def _check_fn_use_yield_and_return(fn):
if isinstance(fn, types.BuiltinFunctionType):
return False
try:
source_code = _get_function_body_without_inners(fn)
has_yield = False
has_return = False
for line in source_code.split("\n"):
if line.lstrip().startswith("yield"):
has_yield = True
if line.lstrip().startswith("return"):
has_return = True
if has_yield and has_return:
return True
return False
except TypeError:
return False


class ParDo(PTransformWithSideInputs):
"""A :class:`ParDo` transform.

Expand Down Expand Up @@ -1427,6 +1478,16 @@ def __init__(self, fn, *args, **kwargs):
if not isinstance(self.fn, DoFn):
raise TypeError('ParDo must be called with a DoFn instance.')

# DoFn.process cannot allow both return and yield
if _check_fn_use_yield_and_return(self.fn.process):
_LOGGER.warning(
'The yield and return statements in the process method '
'of %s can not be mixed.'
'We recommend to use `yield` for emitting individual '
' elements and `yield from` for emitting the content '
'of entire iterables.',
self.fn.__class__)

# Validate the DoFn by creating a DoFnSignature
from apache_beam.runners.common import DoFnSignature
self._signature = DoFnSignature(self.fn)
Expand Down Expand Up @@ -2663,6 +2724,7 @@ def from_runner_api_parameter(unused_ptransform, combine_payload, context):

class CombineValuesDoFn(DoFn):
"""DoFn for performing per-key Combine transforms."""

def __init__(
self,
input_pcoll_type,
Expand Down Expand Up @@ -2725,6 +2787,7 @@ def default_type_hints(self):


class _CombinePerKeyWithHotKeyFanout(PTransform):

def __init__(
self,
combine_fn, # type: CombineFn
Expand Down Expand Up @@ -2939,11 +3002,12 @@ class GroupBy(PTransform):
The GroupBy operation can be made into an aggregating operation by invoking
its `aggregate_field` method.
"""

def __init__(
self,
*fields, # type: typing.Union[str, typing.Callable]
**kwargs # type: typing.Union[str, typing.Callable]
):
):
if len(fields) == 1 and not kwargs:
self._force_tuple_keys = False
name = fields[0] if isinstance(fields[0], str) else 'key'
Expand All @@ -2966,7 +3030,7 @@ def aggregate_field(
field, # type: typing.Union[str, typing.Callable]
combine_fn, # type: typing.Union[typing.Callable, CombineFn]
dest, # type: str
):
):
"""Returns a grouping operation that also aggregates grouped values.

Args:
Expand Down Expand Up @@ -3054,7 +3118,7 @@ def aggregate_field(
field, # type: typing.Union[str, typing.Callable]
combine_fn, # type: typing.Union[typing.Callable, CombineFn]
dest, # type: str
):
):
field = _expr_to_callable(field, 0)
return _GroupAndAggregate(
self._grouping, list(self._aggregations) + [(field, combine_fn, dest)])
Expand Down Expand Up @@ -3096,10 +3160,12 @@ class Select(PTransform):

pcoll | beam.Map(lambda x: beam.Row(a=x.a, b=foo(x)))
"""
def __init__(self,
*args, # type: typing.Union[str, typing.Callable]
**kwargs # type: typing.Union[str, typing.Callable]
):

def __init__(
self,
*args, # type: typing.Union[str, typing.Callable]
**kwargs # type: typing.Union[str, typing.Callable]
):
self._fields = [(
expr if isinstance(expr, str) else 'arg%02d' % ix,
_expr_to_callable(expr, ix)) for (ix, expr) in enumerate(args)
Expand Down Expand Up @@ -3161,14 +3227,16 @@ def expand(self, pcoll):


class Windowing(object):
def __init__(self,
windowfn, # type: WindowFn
triggerfn=None, # type: typing.Optional[TriggerFn]
accumulation_mode=None, # type: typing.Optional[beam_runner_api_pb2.AccumulationMode.Enum]
timestamp_combiner=None, # type: typing.Optional[beam_runner_api_pb2.OutputTime.Enum]
allowed_lateness=0, # type: typing.Union[int, float]
environment_id=None, # type: typing.Optional[str]
):

def __init__(
self,
windowfn, # type: WindowFn
triggerfn=None, # type: typing.Optional[TriggerFn]
accumulation_mode=None, # type: typing.Optional[beam_runner_api_pb2.AccumulationMode.Enum]
timestamp_combiner=None, # type: typing.Optional[beam_runner_api_pb2.OutputTime.Enum]
allowed_lateness=0, # type: typing.Union[int, float]
environment_id=None, # type: typing.Optional[str]
):
"""Class representing the window strategy.

Args:
Expand Down
82 changes: 82 additions & 0 deletions sdks/python/apache_beam/transforms/core_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Unit tests for the core python file."""
# pytype: skip-file

import logging
import unittest

import pytest

import apache_beam as beam


class TestDoFn1(beam.DoFn):
def process(self, element):
yield element


class TestDoFn2(beam.DoFn):
def process(self, element):
def inner_func(x):
yield x

return inner_func(element)


class TestDoFn3(beam.DoFn):
"""mixing return and yield is not allowed"""
def process(self, element):
if not element:
return -1
yield element


class TestDoFn4(beam.DoFn):
"""test the variable name containing return"""
def process(self, element):
my_return = element
yield my_return


class TestDoFn5(beam.DoFn):
"""test the variable name containing yield"""
def process(self, element):
my_yield = element
return my_yield


class CreateTest(unittest.TestCase):
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
self._caplog = caplog

def test_dofn_with_yield_and_return(self):
assert beam.ParDo(sum)
assert beam.ParDo(TestDoFn1())
assert beam.ParDo(TestDoFn2())
assert beam.ParDo(TestDoFn4())
assert beam.ParDo(TestDoFn5())
with self._caplog.at_level(logging.WARNING):
beam.ParDo(TestDoFn3())
assert 'The yield and return statements in' in self._caplog.text


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()