Skip to content

Commit dd9552c

Browse files
authored
fixed the beam yaml create issue by creating the full dict (#35438)
* fixed the beam yaml create issue by creating the full dict * sorted the elements * fixed assert * optimized elements * fixed AssertEqual * fixed the lint
1 parent 2e313a9 commit dd9552c

File tree

2 files changed

+51
-4
lines changed

2 files changed

+51
-4
lines changed

sdks/python/apache_beam/yaml/tests/create.yaml

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ pipelines:
3030
- {element: 2}
3131
- {element: 3}
3232
- {element: 4}
33-
- {element: 5}
33+
- {element: 5}
3434

3535
# Simple Create with more complex beam row
3636
- pipeline:
@@ -64,3 +64,20 @@ pipelines:
6464
- {first: 0, second: [1,2,3]}
6565
- {first: 1, second: [4,5,6]}
6666
- {first: 2, second: [7,8,9]}
67+
68+
# Simple Create with element list
69+
- pipeline:
70+
type: chain
71+
transforms:
72+
- type: Create
73+
config:
74+
elements:
75+
- {sdk: MapReduce, year: 2004}
76+
- {sdk: Flume}
77+
- {sdk: MillWheel, year: 2008}
78+
- type: AssertEqual
79+
config:
80+
elements:
81+
- {sdk: MapReduce, year: 2004}
82+
- {sdk: Flume}
83+
- {sdk: MillWheel, year: 2008}

sdks/python/apache_beam/yaml/yaml_provider.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -785,9 +785,14 @@ def __init__(self, elements: Iterable[Any]):
785785
self._elements = elements
786786

787787
def expand(self, pcoll):
788+
def to_dict(row):
789+
# filter None when comparing
790+
temp_dict = {k: v for k, v in row._asdict().items() if v is not None}
791+
return dict(temp_dict.items())
792+
788793
return assert_that(
789-
pcoll | beam.Map(lambda row: beam.Row(**row._asdict())),
790-
equal_to(dicts_to_rows(self._elements)))
794+
pcoll | beam.Map(to_dict),
795+
equal_to([to_dict(e) for e in dicts_to_rows(self._elements)]))
791796

792797
@staticmethod
793798
def create(elements: Iterable[Any], reshuffle: Optional[bool] = True):
@@ -838,7 +843,32 @@ def create(elements: Iterable[Any], reshuffle: Optional[bool] = True):
838843
# not the intent.
839844
if not isinstance(elements, Iterable) or isinstance(elements, (dict, str)):
840845
raise TypeError('elements must be a list of elements')
841-
return beam.Create([element_to_rows(e) for e in elements],
846+
847+
# Check if elements have different keys
848+
updated_elements = elements
849+
if elements and all(isinstance(e, dict) for e in elements):
850+
keys = [set(e.keys()) for e in elements]
851+
if len(set.union(*keys)) > min(len(k) for k in keys):
852+
# Merge all dictionaries to get all possible keys
853+
all_keys = set()
854+
for element in elements:
855+
if isinstance(element, dict):
856+
all_keys.update(element.keys())
857+
858+
# Create a merged dictionary with all keys
859+
merged_dict = {}
860+
for key in all_keys:
861+
merged_dict[key] = None # Use None as a default value
862+
863+
# Update each element with the merged dictionary
864+
updated_elements = []
865+
for e in elements:
866+
if isinstance(e, dict):
867+
updated_elements.append({**merged_dict, **e})
868+
else:
869+
updated_elements.append(e)
870+
871+
return beam.Create([element_to_rows(e) for e in updated_elements],
842872
reshuffle=reshuffle is not False)
843873

844874
# Or should this be posargs, args?

0 commit comments

Comments
 (0)