diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py index a6080f2f3e7f..3cce2c5bb773 100644 --- a/sdks/python/apache_beam/pipeline.py +++ b/sdks/python/apache_beam/pipeline.py @@ -990,7 +990,11 @@ def expand(self, pcoll): input_element_types_tuple[0] if len(input_element_types_tuple) == 1 else typehints.Union[input_element_types_tuple]) type_hints = transform.get_type_hints() - declared_output_type = type_hints.simple_output_type(transform.label) + if not result_pcollection.tag: + declared_output_type = type_hints.simple_output_type(transform.label) + else: + declared_output_type = type_hints.tagged_output_types().get( + result_pcollection.tag, typehints.Any) if declared_output_type: input_types = type_hints.input_types diff --git a/sdks/python/apache_beam/typehints/tagged_output_typehints_test.py b/sdks/python/apache_beam/typehints/tagged_output_typehints_test.py index c06f68fb88a4..a81579ada48b 100644 --- a/sdks/python/apache_beam/typehints/tagged_output_typehints_test.py +++ b/sdks/python/apache_beam/typehints/tagged_output_typehints_test.py @@ -42,6 +42,7 @@ def fn(element) -> int | TaggedOutput[Literal['errors'], str]: import apache_beam as beam from apache_beam.pvalue import TaggedOutput from apache_beam.typehints import with_output_types +from apache_beam.typehints import Any from apache_beam.typehints.decorators import IOTypeHints @@ -352,5 +353,88 @@ def process( self.assertEqual(results.errors.element_type, str) +class CompositeTaggedOutputInferenceTest(unittest.TestCase): + """Tests for _infer_result_type when a composite PTransform returns + tagged outputs as a dict of fresh PCollections. + """ + def test_composite_returning_tagged_dict_preserves_existing_types(self): + """A composite that returns a dict of PCollections already typed by + DoOutputsTuple.__getitem__ should preserve those types through + _infer_result_type (the guard skips inference when element_type is set).""" + class MyComposite(beam.PTransform): + def expand(self, pcoll): + results = ( + pcoll + | beam.ParDo(self._MyDoFn()).with_outputs('errors', main='main')) + # Return a dict of the DoOutputsTuple's PCollections. + # These already have types set via __getitem__. + return {'main': results.main, 'errors': results.errors} + + @beam.typehints.with_output_types(int, errors=str) + class _MyDoFn(beam.DoFn): + def process(self, element): + if element < 0: + yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}') + else: + yield element + + with beam.Pipeline() as p: + result = (p | beam.Create([-1, 0, 1, 2]) | MyComposite()) + + self.assertEqual(result['main'].element_type, int) + self.assertEqual(result['errors'].element_type, str) + + def test_composite_returning_tagged_dict_without_dofn_hints_is_any(self): + class MyComposite(beam.PTransform): + def expand(self, pcoll): + results = ( + pcoll + | beam.ParDo(self._MyDoFn()).with_outputs('errors', main='main')) + return {'main': results.main, 'errors': results.errors} + + class _MyDoFn(beam.DoFn): + def process(self, element): + if element < 0: + yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}') + else: + yield element + + with beam.Pipeline() as p: + result = (p | beam.Create([-1, 0, 1, 2]) | MyComposite()) + + self.assertEqual(result['errors'].element_type, Any) + + def test_composite_pcollections_uses_tagged_type_hints(self): + """A composite that creates new PCollections (element_type=None) and + returns them as a dict should still get correct tagged types from + the type hints.""" + @beam.typehints.with_output_types(int, errors=str) + class MyComposite(beam.PTransform): + def expand(self, pcoll): + results = ( + pcoll + | beam.ParDo(self._MyDoFn()).with_outputs('errors', main='main')) + return { + p.tag if p.tag else results._main_tag: beam.pvalue.PCollection( + pcoll.pipeline, tag=p.tag) + for p in results + } + + class _MyDoFn(beam.DoFn): + def process( + self, element + ) -> Iterable[int | beam.TaggedOutput[Literal['errors'], str]]: + if element < 0: + yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}') + else: + yield element + + with beam.Pipeline() as p: + result = (p | beam.Create([-1, 0, 1, 2]) | MyComposite()) + + self.assertEqual(result['main'].element_type, int) + self.assertEqual(result['errors'].element_type, str) + + if __name__ == '__main__': unittest.main()