 import traceback
 import types
 import typing
+from itertools import dropwhile
 
 from apache_beam import coders
 from apache_beam import pvalue
@@ -1387,6 +1388,59 @@ def partition_for(self, element, num_partitions, *args, **kwargs):
     return self._fn(element, num_partitions, *args, **kwargs)
 
 
+def _get_function_body_without_inners(func):
+  """Returns the source of func's body, with nested 'def' blocks stripped."""
+  source_lines = inspect.getsourcelines(func)[0]
+  source_lines = dropwhile(lambda x: x.startswith("@"), source_lines)
+  def_line = next(source_lines).strip()
+  if def_line.startswith("def ") and def_line.endswith(":"):
+    first_line = next(source_lines)
+    indentation = len(first_line) - len(first_line.lstrip())
+    final_lines = [first_line[indentation:]]
+
+    skip_inner_def = False
+    if first_line[indentation:].startswith("def "):
+      skip_inner_def = True
+    for line in source_lines:
+      line_indentation = len(line) - len(line.lstrip())
+
+      if line[indentation:].startswith("def "):
+        skip_inner_def = True
+        continue
+
+      if skip_inner_def and line_indentation == indentation:
+        skip_inner_def = False
+
+      if skip_inner_def and line_indentation > indentation:
+        continue
+      final_lines.append(line[indentation:])
+
+    return "".join(final_lines)
+  else:
+    return def_line.rsplit(":")[-1].strip()
+
+
+def _check_fn_use_yield_and_return(fn):
+  """Returns True if fn's body uses both yield and a value-returning return."""
+  if isinstance(fn, types.BuiltinFunctionType):
+    return False
+  try:
+    source_code = _get_function_body_without_inners(fn)
+    has_yield = False
+    has_return = False
+    for line in source_code.split("\n"):
+      if line.lstrip().startswith("yield ") or line.lstrip().startswith(
+          "yield("):
+        has_yield = True
+      if line.lstrip().startswith("return ") or line.lstrip().startswith(
+          "return("):
+        has_return = True
+      if has_yield and has_return:
+        return True
+    return False
+  except Exception as e:
+    _LOGGER.debug(str(e))
+    return False
+
+
 class ParDo(PTransformWithSideInputs):
   """A :class:`ParDo` transform.
 
@@ -1427,6 +1481,14 @@ def __init__(self, fn, *args, **kwargs):
     if not isinstance(self.fn, DoFn):
       raise TypeError('ParDo must be called with a DoFn instance.')
 
+    # DoFn.process should not both yield elements and return a value.
+    if _check_fn_use_yield_and_return(self.fn.process):
+      _LOGGER.warning(
+          'Using yield and return in the process method '
+          'of %s can lead to unexpected behavior, see: '
+          'https://github.com/apache/beam/issues/22969.',
+          self.fn.__class__)
+
     # Validate the DoFn by creating a DoFnSignature
     from apache_beam.runners.common import DoFnSignature
     self._signature = DoFnSignature(self.fn)
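For reference, the pattern the new check flags is a `process` method that mixes `yield` with `return <value>`: once `process` contains a `yield` it is a generator, so a returned value only populates `StopIteration.value` and those elements are silently dropped, which is the behavior described in the linked issue. A minimal sketch of a DoFn that would trigger the warning; the class name and element logic are illustrative only, not part of this change:

    class _MixedOutputDoFn(beam.DoFn):
      def process(self, element):
        if element < 0:
          # Looks like an early exit that emits output, but inside a
          # generator this returned list is discarded, not emitted.
          return [element]
        yield element

Constructing `beam.ParDo(_MixedOutputDoFn())` would log the warning above when the ParDo is built.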
@@ -2663,6 +2725,7 @@ def from_runner_api_parameter(unused_ptransform, combine_payload, context):
 
 class CombineValuesDoFn(DoFn):
   """DoFn for performing per-key Combine transforms."""
+
   def __init__(
       self,
       input_pcoll_type,
@@ -2725,6 +2788,7 @@ def default_type_hints(self):
 
 
 class _CombinePerKeyWithHotKeyFanout(PTransform):
+
   def __init__(
       self,
       combine_fn,  # type: CombineFn
@@ -2939,11 +3003,12 @@ class GroupBy(PTransform):
   The GroupBy operation can be made into an aggregating operation by invoking
   its `aggregate_field` method.
   """
+
   def __init__(
       self,
       *fields,  # type: typing.Union[str, typing.Callable]
       **kwargs  # type: typing.Union[str, typing.Callable]
-      ):
+  ):
     if len(fields) == 1 and not kwargs:
       self._force_tuple_keys = False
       name = fields[0] if isinstance(fields[0], str) else 'key'
@@ -2966,7 +3031,7 @@ def aggregate_field(
       field,  # type: typing.Union[str, typing.Callable]
       combine_fn,  # type: typing.Union[typing.Callable, CombineFn]
       dest,  # type: str
-      ):
+  ):
     """Returns a grouping operation that also aggregates grouped values.
 
     Args:
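The aggregate_field hunks above and below only move the closing parenthesis of the signature; behavior is unchanged. As a reminder of how the method is used, a grouping that also aggregates might look like the sketch below, assuming a PCollection of schema'd rows with `recipe` and `quantity` attributes (the names are illustrative):

    totals = (
        pcoll
        | beam.GroupBy('recipe').aggregate_field('quantity', sum, 'total_quantity'))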
@@ -3054,7 +3119,7 @@ def aggregate_field(
       field,  # type: typing.Union[str, typing.Callable]
       combine_fn,  # type: typing.Union[typing.Callable, CombineFn]
       dest,  # type: str
-      ):
+  ):
     field = _expr_to_callable(field, 0)
     return _GroupAndAggregate(
         self._grouping, list(self._aggregations) + [(field, combine_fn, dest)])
@@ -3096,10 +3161,12 @@ class Select(PTransform):
 
       pcoll | beam.Map(lambda x: beam.Row(a=x.a, b=foo(x)))
   """
-  def __init__(self,
-               *args,  # type: typing.Union[str, typing.Callable]
-               **kwargs  # type: typing.Union[str, typing.Callable]
-              ):
+
+  def __init__(
+      self,
+      *args,  # type: typing.Union[str, typing.Callable]
+      **kwargs  # type: typing.Union[str, typing.Callable]
+  ):
     self._fields = [(
         expr if isinstance(expr, str) else 'arg%02d' % ix,
         _expr_to_callable(expr, ix)) for (ix, expr) in enumerate(args)