Skip to content

Commit 9a5e5b8

Browse files
authored
Raise the Runtime error when DoFn.process uses both yield and return (#25743)
Co-authored-by: xqhu <[email protected]>
1 parent 5f9bf8b commit 9a5e5b8

File tree

2 files changed

+187
-7
lines changed

2 files changed

+187
-7
lines changed

sdks/python/apache_beam/transforms/core.py

Lines changed: 74 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import traceback
2929
import types
3030
import typing
31+
from itertools import dropwhile
3132

3233
from apache_beam import coders
3334
from apache_beam import pvalue
@@ -1387,6 +1388,59 @@ def partition_for(self, element, num_partitions, *args, **kwargs):
13871388
return self._fn(element, num_partitions, *args, **kwargs)
13881389

13891390

1391+
def _get_function_body_without_inners(func):
1392+
source_lines = inspect.getsourcelines(func)[0]
1393+
source_lines = dropwhile(lambda x: x.startswith("@"), source_lines)
1394+
def_line = next(source_lines).strip()
1395+
if def_line.startswith("def ") and def_line.endswith(":"):
1396+
first_line = next(source_lines)
1397+
indentation = len(first_line) - len(first_line.lstrip())
1398+
final_lines = [first_line[indentation:]]
1399+
1400+
skip_inner_def = False
1401+
if first_line[indentation:].startswith("def "):
1402+
skip_inner_def = True
1403+
for line in source_lines:
1404+
line_indentation = len(line) - len(line.lstrip())
1405+
1406+
if line[indentation:].startswith("def "):
1407+
skip_inner_def = True
1408+
continue
1409+
1410+
if skip_inner_def and line_indentation == indentation:
1411+
skip_inner_def = False
1412+
1413+
if skip_inner_def and line_indentation > indentation:
1414+
continue
1415+
final_lines.append(line[indentation:])
1416+
1417+
return "".join(final_lines)
1418+
else:
1419+
return def_line.rsplit(":")[-1].strip()
1420+
1421+
1422+
def _check_fn_use_yield_and_return(fn):
1423+
if isinstance(fn, types.BuiltinFunctionType):
1424+
return False
1425+
try:
1426+
source_code = _get_function_body_without_inners(fn)
1427+
has_yield = False
1428+
has_return = False
1429+
for line in source_code.split("\n"):
1430+
if line.lstrip().startswith("yield ") or line.lstrip().startswith(
1431+
"yield("):
1432+
has_yield = True
1433+
if line.lstrip().startswith("return ") or line.lstrip().startswith(
1434+
"return("):
1435+
has_return = True
1436+
if has_yield and has_return:
1437+
return True
1438+
return False
1439+
except Exception as e:
1440+
_LOGGER.debug(str(e))
1441+
return False
1442+
1443+
13901444
class ParDo(PTransformWithSideInputs):
13911445
"""A :class:`ParDo` transform.
13921446
@@ -1427,6 +1481,14 @@ def __init__(self, fn, *args, **kwargs):
14271481
if not isinstance(self.fn, DoFn):
14281482
raise TypeError('ParDo must be called with a DoFn instance.')
14291483

1484+
# DoFn.process cannot allow both return and yield
1485+
if _check_fn_use_yield_and_return(self.fn.process):
1486+
_LOGGER.warning(
1487+
'Using yield and return in the process method '
1488+
'of %s can lead to unexpected behavior, see:'
1489+
'https://github.com/apache/beam/issues/22969.',
1490+
self.fn.__class__)
1491+
14301492
# Validate the DoFn by creating a DoFnSignature
14311493
from apache_beam.runners.common import DoFnSignature
14321494
self._signature = DoFnSignature(self.fn)
@@ -2663,6 +2725,7 @@ def from_runner_api_parameter(unused_ptransform, combine_payload, context):
26632725

26642726
class CombineValuesDoFn(DoFn):
26652727
"""DoFn for performing per-key Combine transforms."""
2728+
26662729
def __init__(
26672730
self,
26682731
input_pcoll_type,
@@ -2725,6 +2788,7 @@ def default_type_hints(self):
27252788

27262789

27272790
class _CombinePerKeyWithHotKeyFanout(PTransform):
2791+
27282792
def __init__(
27292793
self,
27302794
combine_fn, # type: CombineFn
@@ -2939,11 +3003,12 @@ class GroupBy(PTransform):
29393003
The GroupBy operation can be made into an aggregating operation by invoking
29403004
its `aggregate_field` method.
29413005
"""
3006+
29423007
def __init__(
29433008
self,
29443009
*fields, # type: typing.Union[str, typing.Callable]
29453010
**kwargs # type: typing.Union[str, typing.Callable]
2946-
):
3011+
):
29473012
if len(fields) == 1 and not kwargs:
29483013
self._force_tuple_keys = False
29493014
name = fields[0] if isinstance(fields[0], str) else 'key'
@@ -2966,7 +3031,7 @@ def aggregate_field(
29663031
field, # type: typing.Union[str, typing.Callable]
29673032
combine_fn, # type: typing.Union[typing.Callable, CombineFn]
29683033
dest, # type: str
2969-
):
3034+
):
29703035
"""Returns a grouping operation that also aggregates grouped values.
29713036
29723037
Args:
@@ -3054,7 +3119,7 @@ def aggregate_field(
30543119
field, # type: typing.Union[str, typing.Callable]
30553120
combine_fn, # type: typing.Union[typing.Callable, CombineFn]
30563121
dest, # type: str
3057-
):
3122+
):
30583123
field = _expr_to_callable(field, 0)
30593124
return _GroupAndAggregate(
30603125
self._grouping, list(self._aggregations) + [(field, combine_fn, dest)])
@@ -3096,10 +3161,12 @@ class Select(PTransform):
30963161
30973162
pcoll | beam.Map(lambda x: beam.Row(a=x.a, b=foo(x)))
30983163
"""
3099-
def __init__(self,
3100-
*args, # type: typing.Union[str, typing.Callable]
3101-
**kwargs # type: typing.Union[str, typing.Callable]
3102-
):
3164+
3165+
def __init__(
3166+
self,
3167+
*args, # type: typing.Union[str, typing.Callable]
3168+
**kwargs # type: typing.Union[str, typing.Callable]
3169+
):
31033170
self._fields = [(
31043171
expr if isinstance(expr, str) else 'arg%02d' % ix,
31053172
_expr_to_callable(expr, ix)) for (ix, expr) in enumerate(args)
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
"""Unit tests for the core python file."""
19+
# pytype: skip-file
20+
21+
import logging
22+
import unittest
23+
24+
import pytest
25+
26+
import apache_beam as beam
27+
28+
29+
class TestDoFn1(beam.DoFn):
30+
def process(self, element):
31+
yield element
32+
33+
34+
class TestDoFn2(beam.DoFn):
35+
def process(self, element):
36+
def inner_func(x):
37+
yield x
38+
39+
return inner_func(element)
40+
41+
42+
class TestDoFn3(beam.DoFn):
43+
"""mixing return and yield is not allowed"""
44+
def process(self, element):
45+
if not element:
46+
return -1
47+
yield element
48+
49+
50+
class TestDoFn4(beam.DoFn):
51+
"""test the variable name containing return"""
52+
def process(self, element):
53+
my_return = element
54+
yield my_return
55+
56+
57+
class TestDoFn5(beam.DoFn):
58+
"""test the variable name containing yield"""
59+
def process(self, element):
60+
my_yield = element
61+
return my_yield
62+
63+
64+
class TestDoFn6(beam.DoFn):
65+
"""test the variable name containing return"""
66+
def process(self, element):
67+
return_test = element
68+
yield return_test
69+
70+
71+
class TestDoFn7(beam.DoFn):
72+
"""test the variable name containing yield"""
73+
def process(self, element):
74+
yield_test = element
75+
return yield_test
76+
77+
78+
class TestDoFn8(beam.DoFn):
79+
"""test the code containing yield and yield from"""
80+
def process(self, element):
81+
if not element:
82+
yield from [1, 2, 3]
83+
else:
84+
yield element
85+
86+
87+
class CreateTest(unittest.TestCase):
88+
@pytest.fixture(autouse=True)
89+
def inject_fixtures(self, caplog):
90+
self._caplog = caplog
91+
92+
def test_dofn_with_yield_and_return(self):
93+
warning_text = 'Using yield and return'
94+
95+
with self._caplog.at_level(logging.WARNING):
96+
assert beam.ParDo(sum)
97+
assert beam.ParDo(TestDoFn1())
98+
assert beam.ParDo(TestDoFn2())
99+
assert beam.ParDo(TestDoFn4())
100+
assert beam.ParDo(TestDoFn5())
101+
assert beam.ParDo(TestDoFn6())
102+
assert beam.ParDo(TestDoFn7())
103+
assert beam.ParDo(TestDoFn8())
104+
assert warning_text not in self._caplog.text
105+
106+
with self._caplog.at_level(logging.WARNING):
107+
beam.ParDo(TestDoFn3())
108+
assert warning_text in self._caplog.text
109+
110+
111+
if __name__ == '__main__':
112+
logging.getLogger().setLevel(logging.INFO)
113+
unittest.main()

0 commit comments

Comments
 (0)