Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion sdks/python/apache_beam/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -990,7 +990,11 @@ def expand(self, pcoll):
input_element_types_tuple[0] if len(input_element_types_tuple) == 1
else typehints.Union[input_element_types_tuple])
type_hints = transform.get_type_hints()
declared_output_type = type_hints.simple_output_type(transform.label)
if not result_pcollection.tag:
declared_output_type = type_hints.simple_output_type(transform.label)
else:
declared_output_type = type_hints.tagged_output_types().get(
result_pcollection.tag, typehints.Any)

if declared_output_type:
input_types = type_hints.input_types
Expand Down
84 changes: 84 additions & 0 deletions sdks/python/apache_beam/typehints/tagged_output_typehints_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def fn(element) -> int | TaggedOutput[Literal['errors'], str]:
import apache_beam as beam
from apache_beam.pvalue import TaggedOutput
from apache_beam.typehints import with_output_types
from apache_beam.typehints import Any
from apache_beam.typehints.decorators import IOTypeHints


Expand Down Expand Up @@ -352,5 +353,88 @@ def process(
self.assertEqual(results.errors.element_type, str)


class CompositeTaggedOutputInferenceTest(unittest.TestCase):
"""Tests for _infer_result_type when a composite PTransform returns
tagged outputs as a dict of fresh PCollections.
"""
def test_composite_returning_tagged_dict_preserves_existing_types(self):
"""A composite that returns a dict of PCollections already typed by
DoOutputsTuple.__getitem__ should preserve those types through
_infer_result_type (the guard skips inference when element_type is set)."""
class MyComposite(beam.PTransform):
def expand(self, pcoll):
results = (
pcoll
| beam.ParDo(self._MyDoFn()).with_outputs('errors', main='main'))
# Return a dict of the DoOutputsTuple's PCollections.
# These already have types set via __getitem__.
return {'main': results.main, 'errors': results.errors}

@beam.typehints.with_output_types(int, errors=str)
class _MyDoFn(beam.DoFn):
def process(self, element):
if element < 0:
yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}')
else:
yield element

with beam.Pipeline() as p:
result = (p | beam.Create([-1, 0, 1, 2]) | MyComposite())

self.assertEqual(result['main'].element_type, int)
self.assertEqual(result['errors'].element_type, str)

def test_composite_returning_tagged_dict_without_dofn_hints_is_any(self):
class MyComposite(beam.PTransform):
def expand(self, pcoll):
results = (
pcoll
| beam.ParDo(self._MyDoFn()).with_outputs('errors', main='main'))
return {'main': results.main, 'errors': results.errors}

class _MyDoFn(beam.DoFn):
def process(self, element):
if element < 0:
yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}')
else:
yield element

with beam.Pipeline() as p:
result = (p | beam.Create([-1, 0, 1, 2]) | MyComposite())

self.assertEqual(result['errors'].element_type, Any)

def test_composite_pcollections_uses_tagged_type_hints(self):
"""A composite that creates new PCollections (element_type=None) and
returns them as a dict should still get correct tagged types from
the type hints."""
@beam.typehints.with_output_types(int, errors=str)
class MyComposite(beam.PTransform):
def expand(self, pcoll):
results = (
pcoll
| beam.ParDo(self._MyDoFn()).with_outputs('errors', main='main'))
return {
p.tag if p.tag else results._main_tag: beam.pvalue.PCollection(
pcoll.pipeline, tag=p.tag)
for p in results
}

class _MyDoFn(beam.DoFn):
def process(
self, element
) -> Iterable[int | beam.TaggedOutput[Literal['errors'], str]]:
if element < 0:
yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}')
else:
yield element

with beam.Pipeline() as p:
result = (p | beam.Create([-1, 0, 1, 2]) | MyComposite())

self.assertEqual(result['main'].element_type, int)
self.assertEqual(result['errors'].element_type, str)


if __name__ == '__main__':
unittest.main()
Loading