Skip to content

Commit 26455ca

Browse files
committed
Add extra fields
1 parent cbf6328 commit 26455ca

File tree

1 file changed

+21
-1
lines changed

1 file changed

+21
-1
lines changed

airbyte_cdk/sources/declarative/partition_routers/grouping_partition_router.py

Lines changed: 21 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@ class GroupingPartitionRouter(PartitionRouter):
1717
1818
Attributes:
1919
group_size (int): The number of partitions to include in each group.
20-
underlying_partition_router (SinglePartitionRouter): The partition router whose output will be grouped.
20+
underlying_partition_router (PartitionRouter): The partition router whose output will be grouped.
2121
deduplicate (bool): If True, ensures unique partitions within each group by removing duplicates based on the partition key.
2222
config (Config): The connector configuration.
2323
parameters (Mapping[str, Any]): Additional parameters for interpolation and configuration.
@@ -66,13 +66,33 @@ def stream_slices(self) -> Iterable[StreamSlice]:
6666
yield self._create_grouped_slice(batch)
6767

6868
def _create_grouped_slice(self, batch: list[StreamSlice]) -> StreamSlice:
69+
"""
70+
Creates a grouped StreamSlice from a batch of partitions, aggregating extra fields into a dictionary with list values.
71+
72+
Args:
73+
batch (list[StreamSlice]): A list of StreamSlice objects to group.
74+
75+
Returns:
76+
StreamSlice: A single StreamSlice with combined partition and extra field values.
77+
"""
6978
# Combine partition values into a single dict with lists
7079
grouped_partition = {
7180
key: [p.partition.get(key) for p in batch] for key in batch[0].partition.keys()
7281
}
82+
83+
# Aggregate extra fields into a dict with list values
84+
extra_fields_dict = (
85+
{
86+
key: [p.extra_fields.get(key) for p in batch]
87+
for key in set().union(*(p.extra_fields.keys() for p in batch if p.extra_fields))
88+
}
89+
if any(p.extra_fields for p in batch)
90+
else {}
91+
)
7392
return StreamSlice(
7493
partition=grouped_partition,
7594
cursor_slice={}, # Cursor is managed by the underlying router or incremental sync
95+
extra_fields=extra_fields_dict,
7696
)
7797

7898
def get_request_params(

0 commit comments

Comments
 (0)