Skip to content

Commit b56f99d

Browse files
authored
Fix incorrect typehints generated by FlatMap with default identity function (#35164)
* create unit test * add some unit tests * fix bug where T is considered iterable * update strip_iterable to return Any for "stripped iterable" type of TypeVariable * remove typehint from identity function and add a test to test for proper typechecking * Move callablewrapp typehint test * Remove print * isort * isort * return any for yielded type of T
1 parent 094a315 commit b56f99d

File tree

6 files changed

+43
-1
lines changed

6 files changed

+43
-1
lines changed

sdks/python/apache_beam/transforms/core.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2016,7 +2016,7 @@ def to_runner_api(self, unused_context):
20162016
return beam_runner_api_pb2.FunctionSpec(urn=self._urn)
20172017

20182018

2019-
def identity(x: T) -> T:
2019+
def identity(x):
20202020
return x
20212021

20222022

@@ -2053,6 +2053,7 @@ def FlatMap(fn=identity, *args, **kwargs): # pylint: disable=invalid-name
20532053

20542054
pardo = ParDo(CallableWrapperDoFn(fn), *args, **kwargs)
20552055
pardo.label = label
2056+
20562057
return pardo
20572058

20582059

sdks/python/apache_beam/transforms/core_test.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,16 @@
2222
import os
2323
import tempfile
2424
import unittest
25+
from typing import TypeVar
2526

2627
import pytest
2728

2829
import apache_beam as beam
2930
from apache_beam.testing.util import assert_that
3031
from apache_beam.testing.util import equal_to
3132
from apache_beam.transforms.window import FixedWindows
33+
from apache_beam.typehints import TypeCheckError
34+
from apache_beam.typehints import typehints
3235

3336
RETURN_NONE_PARTIAL_WARNING = "No iterator is returned"
3437

@@ -279,6 +282,16 @@ def failure_callback(e, el):
279282
self.assertFalse(os.path.isfile(tmp_path))
280283

281284

285+
def test_callablewrapper_typehint():
286+
T = TypeVar("T")
287+
288+
def identity(x: T) -> T:
289+
return x
290+
291+
dofn = beam.core.CallableWrapperDoFn(identity)
292+
assert dofn.get_type_hints().strip_iterable()[1][0][0] == typehints.Any
293+
294+
282295
class FlatMapTest(unittest.TestCase):
283296
def test_default(self):
284297

@@ -289,6 +302,25 @@ def test_default(self):
289302
| beam.FlatMap())
290303
assert_that(letters, equal_to(['a', 'b', 'c', 'd', 'e', 'f']))
291304

305+
def test_default_identity_function_with_typehint(self):
306+
with beam.Pipeline() as pipeline:
307+
letters = (
308+
pipeline
309+
| beam.Create([["abc"]], reshuffle=False)
310+
| beam.FlatMap()
311+
| beam.Map(lambda s: s.upper()).with_input_types(str))
312+
313+
assert_that(letters, equal_to(["ABC"]))
314+
315+
def test_typecheck_with_default(self):
316+
with pytest.raises(TypeCheckError):
317+
with beam.Pipeline() as pipeline:
318+
_ = (
319+
pipeline
320+
| beam.Create([[1, 2, 3]], reshuffle=False)
321+
| beam.FlatMap()
322+
| beam.Map(lambda s: s.upper()).with_input_types(str))
323+
292324

293325
if __name__ == '__main__':
294326
logging.getLogger().setLevel(logging.INFO)

sdks/python/apache_beam/typehints/decorators.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,11 @@ def strip_iterable(self) -> 'IOTypeHints':
424424
output_type = types[0]
425425
except ValueError:
426426
pass
427+
if isinstance(output_type, typehints.TypeVariable):
428+
# We don't know what T yields, so we just assume Any.
429+
return self._replace(
430+
output_types=((typehints.Any, ), {}),
431+
origin=self._make_origin([self], tb=False, msg=['strip_iterable()']))
427432

428433
yielded_type = typehints.get_yielded_type(output_type)
429434
return self._replace(

sdks/python/apache_beam/typehints/decorators_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def _test_strip_iterable_fail(self, before):
131131
def test_strip_iterable(self):
132132
self._test_strip_iterable(None, None)
133133
self._test_strip_iterable(typehints.Any, typehints.Any)
134+
self._test_strip_iterable(T, typehints.Any)
134135
self._test_strip_iterable(typehints.Iterable[str], str)
135136
self._test_strip_iterable(typehints.List[str], str)
136137
self._test_strip_iterable(typehints.Iterator[str], str)

sdks/python/apache_beam/typehints/typehints.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1563,6 +1563,8 @@ def get_yielded_type(type_hint):
15631563
Raises:
15641564
ValueError if not iterable.
15651565
"""
1566+
if isinstance(type_hint, typing.TypeVar):
1567+
return typing.Any
15661568
if isinstance(type_hint, AnyTypeConstraint):
15671569
return type_hint
15681570
if is_consistent_with(type_hint, Iterator[Any]):

sdks/python/apache_beam/typehints/typehints_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1654,6 +1654,7 @@ def test_iterables(self):
16541654
typehints.get_yielded_type(typehints.Tuple[int, str]))
16551655
self.assertEqual(int, typehints.get_yielded_type(typehints.Set[int]))
16561656
self.assertEqual(int, typehints.get_yielded_type(typehints.FrozenSet[int]))
1657+
self.assertEqual(typing.Any, typehints.get_yielded_type(T))
16571658
self.assertEqual(
16581659
typehints.Union[int, str],
16591660
typehints.get_yielded_type(

0 commit comments

Comments
 (0)