Skip to content

Commit b60ce5b

Browse files
Add type checking to confirm that the flatten inputs are actually pcollections (#35874)
* Add type checking to confirm that the flatten inputs are actually pcollections
* Allow iterables of non-pcollections to flatten again
* Update sdks/python/apache_beam/transforms/core.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent d2e1129 commit b60ce5b

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

sdks/python/apache_beam/transforms/core.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3866,6 +3866,15 @@ def _extract_input_pvalues(self, pvalueish):
38663866
raise ValueError(
38673867
'Input to Flatten must be an iterable. '
38683868
'Got a value of type %s instead.' % type(pvalueish))
3869+
3870+
# Spot check to see if any of the items are iterables of PCollections
3871+
# and raise an error if so. This is always a user-error
3872+
for idx, item in enumerate(pvalueish):
3873+
if isinstance(item, (list, tuple)) and any(
3874+
isinstance(sub_item, pvalue.PCollection) for sub_item in item):
3875+
raise TypeError(
3876+
'Inputs to Flatten cannot include an iterable of PCollections. '
3877+
f'(input at index {idx}: "{item}")')
38693878
return pvalueish, pvalueish
38703879

38713880
def expand(self, pcolls):

sdks/python/apache_beam/typehints/typecheck_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,15 @@ def process(self, element, *args, **kwargs):
179179
(self.p | beam.Create(['1', '1']) | beam.ParDo(ToInt()))
180180
self.p.run().wait_until_finish()
181181

182+
def test_bad_flatten_input(self):
183+
with self.assertRaisesRegex(
184+
TypeError,
185+
"Inputs to Flatten cannot include an iterable of PCollections. "):
186+
with beam.Pipeline() as p:
187+
pc = p | beam.Create([1, 1])
188+
flatten_inputs = [pc, (pc, )]
189+
flatten_inputs | beam.Flatten()
190+
182191
def test_do_fn_returning_non_iterable_throws_error(self):
183192
# This function is incorrect because it returns a non-iterable object
184193
def incorrect_par_do_fn(x):

0 commit comments

Comments
 (0)