Skip to content

Commit f7e5ee5

Browse files
authored
feat(yaml): add a simple schema unification for Flatten transform (#35728)
* feat(yaml): add schema unification for Flatten transform

  Implement schema merging for the Flatten transform to handle PCollections with different schemas. The unified schema contains all fields from the input PCollections, with fields made optional to handle missing values. Added a test case to verify the behavior.

* refactor(yaml_provider): improve schema unification by handling optional types

  Extract inner types from Optional when unifying schemas to properly handle type unions. Also improve code readability by breaking long lines and clarifying comments.

* fix(yaml_provider): handle nested generic types in field type resolution

  Fix type resolution for nested generic types by properly extracting inner types when comparing field types. This ensures correct type hints are generated for optional fields in the YAML provider.

* fix tests

* fix(yaml_provider): improve type handling and list conversion in schema unification

  Handle list types more carefully during schema unification to avoid unsupported Union types. Also ensure iterable values are properly converted to lists when needed for schema compatibility.

* fix lint

* refactor(yaml): remove unused import in yaml_provider and add flatten tests

  Add comprehensive test cases for schema unification in the Flatten transform.

* fix lint

* fix lint

* feat(yaml): add a simple schema unification for Flatten transform

* fix lint

* feat(yaml): refactor element unification logic into a standalone function

* use AssertEqual
1 parent 132b24b commit f7e5ee5

File tree

3 files changed

+312
-3
lines changed

3 files changed

+312
-3
lines changed

sdks/python/apache_beam/yaml/yaml_provider.py

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -756,6 +756,43 @@ def dicts_to_rows(o):
756756
return o
757757

758758

759+
def _unify_element_with_schema(element, target_schema):
760+
"""Convert an element to match the target schema, preserving existing
761+
fields only."""
762+
if target_schema is None:
763+
return element
764+
765+
# If element is already a named tuple, convert to dict first
766+
if hasattr(element, '_asdict'):
767+
element_dict = element._asdict()
768+
elif isinstance(element, dict):
769+
element_dict = element
770+
else:
771+
# This element is not a row, so it can't be unified with a
772+
# row schema.
773+
return element
774+
775+
# Create new element with only the fields that exist in the original
776+
# element plus None for fields that are expected but missing
777+
unified_dict = {}
778+
for field_name in target_schema._fields:
779+
if field_name in element_dict:
780+
value = element_dict[field_name]
781+
# Ensure the value matches the expected type
782+
# This is particularly important for list fields
783+
if value is not None and not isinstance(value, list) and hasattr(
784+
value, '__iter__') and not isinstance(
785+
value, (str, bytes)) and not hasattr(value, '_asdict'):
786+
# Convert iterables to lists if needed
787+
unified_dict[field_name] = list(value)
788+
else:
789+
unified_dict[field_name] = value
790+
else:
791+
unified_dict[field_name] = None
792+
793+
return target_schema(**unified_dict)
794+
795+
759796
class YamlProviders:
760797
class AssertEqual(beam.PTransform):
761798
"""Asserts that the input contains exactly the elements provided.
@@ -932,6 +969,48 @@ def __init__(self):
932969
# pylint: disable=useless-parent-delegation
933970
super().__init__()
934971

def _merge_schemas(self, pcolls):
  """Merge schemas from multiple PCollections to create a unified schema.

  Builds one schema containing every field seen across the input
  PCollections. Every field is made optional so elements missing it can
  carry None; a field that appears with more than one type is widened to
  Optional[Any].

  Returns:
    A named-tuple type for the unified schema, or None when no input
    schema could be determined.
  """
  from apache_beam.typehints.schemas import named_fields_from_element_type

  # Gather the field mapping of every input whose schema can be extracted.
  field_maps = []
  for pcoll in pcolls:
    element_type = getattr(pcoll, 'element_type', None)
    if not element_type:
      continue
    try:
      field_maps.append(dict(named_fields_from_element_type(element_type)))
    except (ValueError, TypeError):
      # If we can't extract schema, skip this PCollection.
      pass

  if not field_maps:
    return None

  # Union of all field names, each mapped to an optional (possibly widened)
  # type.
  merged = {}
  for name in set().union(*(m.keys() for m in field_maps)):
    types_seen = {m[name] for m in field_maps if name in m}
    if len(types_seen) == 1:
      merged[name] = Optional[types_seen.pop()]
    else:
      merged[name] = Optional[Any]

  if not merged:
    return None

  # Build the unified named-tuple row type.
  from apache_beam.typehints.schemas import named_fields_to_schema
  from apache_beam.typehints.schemas import named_tuple_from_schema
  return named_tuple_from_schema(named_fields_to_schema(list(merged.items())))
9351014
def expand(self, pcolls):
9361015
if isinstance(pcolls, beam.PCollection):
9371016
pipeline_arg = {}
@@ -942,7 +1021,27 @@ def expand(self, pcolls):
9421021
else:
9431022
pipeline_arg = {'pipeline': pcolls.pipeline}
9441023
pcolls = ()
945-
return pcolls | beam.Flatten(**pipeline_arg)
1024+
1025+
if not pcolls:
1026+
return pcolls | beam.Flatten(**pipeline_arg)
1027+
1028+
# Try to unify schemas
1029+
unified_schema = self._merge_schemas(pcolls)
1030+
1031+
if unified_schema is None:
1032+
# No schema unification needed, use standard flatten
1033+
return pcolls | beam.Flatten(**pipeline_arg)
1034+
1035+
# Apply schema unification to each PCollection before flattening.
1036+
unified_pcolls = []
1037+
for i, pcoll in enumerate(pcolls):
1038+
unified_pcoll = pcoll | f'UnifySchema{i}' >> beam.Map(
1039+
_unify_element_with_schema,
1040+
target_schema=unified_schema).with_output_types(unified_schema)
1041+
unified_pcolls.append(unified_pcoll)
1042+
1043+
# Flatten the unified PCollections
1044+
return unified_pcolls | beam.Flatten(**pipeline_arg)
9461045

9471046
class WindowInto(beam.PTransform):
9481047
# pylint: disable=line-too-long

sdks/python/apache_beam/yaml/yaml_transform.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -861,8 +861,8 @@ def preprocess_flattened_inputs(spec):
861861
def all_inputs(t):
  # Yield a (name, input) pair for every input of the transform spec,
  # expanding list-valued inputs into indexed names (key0, key1, ...).
  for key, values in t.get('input', {}).items():
    if not isinstance(values, list):
      yield key, values
    else:
      for ix, value in enumerate(values):
        yield f'{key}{ix}', value
868868

sdks/python/apache_beam/yaml/yaml_transform_test.py

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,216 @@ def test_composite_resource_hints(self):
477477
b'1000000000',
478478
proto)
479479

480+
def test_flatten_unifies_schemas(self):
  """Test that Flatten fills fields missing from one input with None."""
  with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
      pickle_library='cloudpickle')) as p:
    _ = p | YamlTransform(
        '''
        type: composite
        transforms:
          - type: Create
            name: Create1
            config:
              elements:
                - {ride_id: '1', passenger_count: 1}
                - {ride_id: '2', passenger_count: 2}
          - type: Create
            name: Create2
            config:
              elements:
                - {ride_id: '3'}
                - {ride_id: '4'}
          - type: Flatten
            input: [Create1, Create2]
          - type: AssertEqual
            input: Flatten
            config:
              elements:
                - {ride_id: '1', passenger_count: 1}
                - {ride_id: '2', passenger_count: 2}
                - {ride_id: '3'}
                - {ride_id: '4'}
        ''')
511+
def test_flatten_unifies_optional_fields(self):
  """Test that Flatten correctly unifies schemas with optional fields."""
  with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
      pickle_library='cloudpickle')) as p:
    _ = p | YamlTransform(
        '''
        type: composite
        transforms:
          - type: Create
            name: Create1
            config:
              elements:
                - {id: '1', name: 'Alice', age: 30}
                - {id: '2', name: 'Bob', age: 25}
          - type: Create
            name: Create2
            config:
              elements:
                - {id: '3', name: 'Charlie'}
                - {id: '4', name: 'Diana'}
          - type: Flatten
            input: [Create1, Create2]
          - type: AssertEqual
            input: Flatten
            config:
              elements:
                - {id: '1', name: 'Alice', age: 30}
                - {id: '2', name: 'Bob', age: 25}
                - {id: '3', name: 'Charlie'}
                - {id: '4', name: 'Diana'}
        ''')
543+
def test_flatten_unifies_different_types(self):
  """Test that Flatten correctly unifies schemas with different
  field types."""
  with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
      pickle_library='cloudpickle')) as p:
    _ = p | YamlTransform(
        '''
        type: composite
        transforms:
          - type: Create
            name: Create1
            config:
              elements:
                - {id: 1, value: 100}
                - {id: 2, value: 200}
          - type: Create
            name: Create2
            config:
              elements:
                - {id: '3', value: 'text'}
                - {id: '4', value: 'data'}
          - type: Flatten
            input: [Create1, Create2]
          - type: AssertEqual
            input: Flatten
            config:
              elements:
                - {id: 1, value: 100}
                - {id: 2, value: 200}
                - {id: '3', value: 'text'}
                - {id: '4', value: 'data'}
        ''')
575+
576+
def test_flatten_unifies_list_fields(self):
  """Test that Flatten correctly unifies schemas with list fields."""
  with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
      pickle_library='cloudpickle')) as p:
    _ = p | YamlTransform(
        '''
        type: composite
        transforms:
          - type: Create
            name: Create1
            config:
              elements:
                - {id: '1', tags: ['red', 'blue']}
                - {id: '2', tags: ['green']}
          - type: Create
            name: Create2
            config:
              elements:
                - {id: '3', tags: ['yellow', 'purple', 'orange']}
                - {id: '4', tags: []}
          - type: Flatten
            input: [Create1, Create2]
          - type: AssertEqual
            input: Flatten
            config:
              elements:
                - {id: '1', tags: ['red', 'blue']}
                - {id: '2', tags: ['green']}
                - {id: '3', tags: ['yellow', 'purple', 'orange']}
                - {id: '4', tags: []}
        ''')
607+
608+
def test_flatten_unifies_with_missing_fields(self):
  """Test that Flatten correctly unifies schemas when some inputs have
  missing fields."""
  with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
      pickle_library='cloudpickle')) as p:
    _ = p | YamlTransform(
        '''
        type: composite
        transforms:
          - type: Create
            name: Create1
            config:
              elements:
                - {id: '1', name: 'Alice', department: 'Engineering',
                   salary: 75000}
                - {id: '2', name: 'Bob', department: 'Marketing',
                   salary: 65000}
          - type: Create
            name: Create2
            config:
              elements:
                - {id: '3', name: 'Charlie', department: 'Sales'}
                - {id: '4', name: 'Diana'}
          - type: Flatten
            input: [Create1, Create2]
          - type: AssertEqual
            input: Flatten
            config:
              elements:
                - {id: '1', name: 'Alice', department: 'Engineering',
                   salary: 75000}
                - {id: '2', name: 'Bob', department: 'Marketing',
                   salary: 65000}
                - {id: '3', name: 'Charlie', department: 'Sales'}
                - {id: '4', name: 'Diana'}
        ''')
644+
645+
def test_flatten_unifies_complex_mixed_schemas(self):
  """Test that Flatten correctly unifies complex mixed
  schemas."""
  with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
      pickle_library='cloudpickle')) as p:
    _ = p | YamlTransform(
        '''
        type: composite
        transforms:
          - type: Create
            name: Create1
            config:
              elements:
                - {id: 1, name: 'Product A', price: 29.99,
                   categories: ['electronics', 'gadgets']}
                - {id: 2, name: 'Product B', price: 15.50,
                   categories: ['books']}
          - type: Create
            name: Create2
            config:
              elements:
                - {id: 3, name: 'Product C', categories: ['clothing']}
                - {id: 4, name: 'Product D', price: 99.99}
          - type: Create
            name: Create3
            config:
              elements:
                - {id: 5, name: 'Product E', price: 5.00,
                   categories: []}
          - type: Flatten
            input: [Create1, Create2, Create3]
          - type: AssertEqual
            input: Flatten
            config:
              elements:
                - {id: 1, name: 'Product A', price: 29.99,
                   categories: ['electronics', 'gadgets']}
                - {id: 2, name: 'Product B', price: 15.50,
                   categories: ['books']}
                - {id: 3, name: 'Product C', categories: ['clothing']}
                - {id: 4, name: 'Product D', price: 99.99}
                - {id: 5, name: 'Product E', price: 5.00,
                   categories: []}
        ''')
689+
480690

481691
class ErrorHandlingTest(unittest.TestCase):
482692
def test_error_handling_outputs(self):

0 commit comments

Comments
 (0)