Skip to content

Commit 701f93d

Browse files
committed
Fix issue with substream partition router picking extra fields
1 parent f525803 commit 701f93d

File tree

4 files changed

+169
-83
lines changed

4 files changed

+169
-83
lines changed

airbyte_cdk/sources/declarative/partition_routers/cartesian_product_stream_slicer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ def stream_slices(self) -> Iterable[StreamSlice]:
149149
for stream_slice_tuple in product:
150150
partition = dict(ChainMap(*[s.partition for s in stream_slice_tuple])) # type: ignore # ChainMap expects a MutableMapping[Never, Never] for reasons
151151
cursor_slices = [s.cursor_slice for s in stream_slice_tuple if s.cursor_slice]
152+
extra_fields = dict(ChainMap(*[s.extra_fields for s in stream_slice_tuple]))
152153
if len(cursor_slices) > 1:
153154
raise ValueError(
154155
f"There should only be a single cursor slice. Found {cursor_slices}"
@@ -157,7 +158,9 @@ def stream_slices(self) -> Iterable[StreamSlice]:
157158
cursor_slice = cursor_slices[0]
158159
else:
159160
cursor_slice = {}
160-
yield StreamSlice(partition=partition, cursor_slice=cursor_slice)
161+
yield StreamSlice(
162+
partition=partition, cursor_slice=cursor_slice, extra_fields=extra_fields
163+
)
161164

162165
def set_initial_state(self, stream_state: StreamState) -> None:
163166
"""
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
#
2+
# Copyright (c) 2025 Airbyte, Inc., all rights reserved.
3+
#
4+
5+
from typing import Any, Iterable, List, Mapping, Optional, Union
6+
7+
8+
from airbyte_cdk.models import SyncMode
9+
from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream
10+
from airbyte_cdk.sources.declarative.incremental.per_partition_cursor import (
11+
StreamSlice,
12+
)
13+
from airbyte_cdk.sources.declarative.interpolation import InterpolatedString
14+
from airbyte_cdk.sources.streams.checkpoint import Cursor
15+
from airbyte_cdk.sources.types import Record
16+
17+
18+
class MockStream(DeclarativeStream):
    """In-memory stand-in for a declarative stream, used by partition-router tests.

    Serves pre-canned slices and records instead of performing any I/O, and
    tracks per-partition cursor state in ``self._state["states"]``.
    """

    def __init__(self, slices, records, name, cursor_field="", cursor=None):
        self.config = {}
        self._slices = slices
        self._records = records
        # Accept either a plain string (interpolated lazily) or an
        # already-constructed InterpolatedString-like object.
        self._stream_cursor_field = (
            InterpolatedString.create(cursor_field, parameters={})
            if isinstance(cursor_field, str)
            else cursor_field
        )
        self._name = name
        self._state = {"states": []}
        self._cursor = cursor

    @property
    def name(self) -> str:
        return self._name

    @property
    def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]:
        return "id"

    @property
    def state(self) -> Mapping[str, Any]:
        return self._state

    @state.setter
    def state(self, value: Mapping[str, Any]) -> None:
        self._state = value

    @property
    def is_resumable(self) -> bool:
        # Resumable only when a cursor was injected at construction time.
        return bool(self._cursor)

    def get_cursor(self) -> Optional[Cursor]:
        return self._cursor

    def stream_slices(
        self,
        *,
        sync_mode: SyncMode,
        cursor_field: Optional[List[str]] = None,
        stream_state: Optional[Mapping[str, Any]] = None,
    ) -> Iterable[Optional[StreamSlice]]:
        """Yield the pre-canned slices, wrapping plain mappings as StreamSlice."""
        for s in self._slices:
            if isinstance(s, StreamSlice):
                yield s
            else:
                yield StreamSlice(partition=s, cursor_slice={})

    def read_records(
        self,
        sync_mode: SyncMode,
        cursor_field: Optional[List[str]] = None,
        stream_slice: Optional[Mapping[str, Any]] = None,
        stream_state: Optional[Mapping[str, Any]] = None,
    ) -> Iterable[Mapping[str, Any]]:
        """Yield the records matching ``stream_slice`` and update cursor state.

        When no slice is given, all records are returned; otherwise only those
        whose ``"slice"`` value equals the slice's ``"slice"`` value.
        """
        # The parent stream's records should always be read as full refresh
        assert sync_mode == SyncMode.full_refresh

        if not stream_slice:
            result = self._records
        else:
            result = [
                Record(data=r, associated_slice=stream_slice, stream_name=self.name)
                for r in self._records
                if r["slice"] == stream_slice["slice"]
            ]

        yield from result

        # Update the state only after reading the full slice
        cursor_field = self._stream_cursor_field.eval(config=self.config)
        if stream_slice and cursor_field and result:
            self._state["states"].append(
                {cursor_field: result[-1][cursor_field], "partition": stream_slice["slice"]}
            )

    def get_json_schema(self) -> Mapping[str, Any]:
        return {}

unit_tests/sources/declarative/partition_routers/test_cartesian_product_partition_router.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@
1616
RequestOptionType,
1717
)
1818
from airbyte_cdk.sources.types import StreamSlice
19+
from airbyte_cdk.sources.declarative.partition_routers.substream_partition_router import (
20+
ParentStreamConfig,
21+
SubstreamPartitionRouter,
22+
)
23+
from .helpers import MockStream
1924

2025

2126
@pytest.mark.parametrize(
@@ -171,6 +176,68 @@ def test_substream_slicer(test_name, stream_slicers, expected_slices):
171176
assert slices == expected_slices
172177

173178

179+
@pytest.mark.parametrize(
    "test_name, stream_slicers, expected_slices",
    [
        (
            "test_single_stream_slicer",
            [
                SubstreamPartitionRouter(
                    parent_stream_configs=[
                        ParentStreamConfig(
                            stream=MockStream(
                                [{}],
                                [
                                    {"a": {"b": 0}, "extra_field_key": "extra_field_value_0"},
                                    {"a": {"b": 1}, "extra_field_key": "extra_field_value_1"},
                                    {"a": {"c": 2}, "extra_field_key": "extra_field_value_2"},
                                    {"a": {"b": 3}, "extra_field_key": "extra_field_value_3"},
                                ],
                                "first_stream",
                            ),
                            parent_key="a/b",
                            partition_field="first_stream_id",
                            parameters={},
                            config={},
                            extra_fields=[["extra_field_key"]],
                        )
                    ],
                    parameters={},
                    config={},
                ),
            ],
            [
                StreamSlice(
                    partition={"first_stream_id": 0, "parent_slice": {}},
                    cursor_slice={},
                    extra_fields={"extra_field_key": "extra_field_value_0"},
                ),
                StreamSlice(
                    partition={"first_stream_id": 1, "parent_slice": {}},
                    cursor_slice={},
                    extra_fields={"extra_field_key": "extra_field_value_1"},
                ),
                StreamSlice(
                    partition={"first_stream_id": 3, "parent_slice": {}},
                    cursor_slice={},
                    extra_fields={"extra_field_key": "extra_field_value_3"},
                ),
            ],
        )
    ],
)
def test_substream_slicer_with_extra_fields(test_name, stream_slicers, expected_slices):
    """The cartesian slicer must carry extra_fields through from substream partitions."""
    slicer = CartesianProductStreamSlicer(stream_slicers=stream_slicers, parameters={})
    actual_slices = list(slicer.stream_slices())

    # Partitions and extra_fields are compared separately so a failure
    # pinpoints which part of the slice diverged.
    assert [s.partition for s in actual_slices] == [s.partition for s in expected_slices]
    assert [s.extra_fields for s in actual_slices] == [
        s.extra_fields for s in expected_slices
    ]
240+
174241
def test_stream_slices_raises_exception_if_multiple_cursor_slice_components():
175242
stream_slicers = [
176243
DatetimeBasedCursor(

unit_tests/sources/declarative/partition_routers/test_substream_partition_router.py

Lines changed: 1 addition & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -62,88 +62,7 @@
6262
all_parent_data_with_cursor = (
6363
data_first_parent_slice_with_cursor + data_second_parent_slice_with_cursor
6464
)
65-
66-
67-
class MockStream(DeclarativeStream):
68-
def __init__(self, slices, records, name, cursor_field="", cursor=None):
69-
self.config = {}
70-
self._slices = slices
71-
self._records = records
72-
self._stream_cursor_field = (
73-
InterpolatedString.create(cursor_field, parameters={})
74-
if isinstance(cursor_field, str)
75-
else cursor_field
76-
)
77-
self._name = name
78-
self._state = {"states": []}
79-
self._cursor = cursor
80-
81-
@property
82-
def name(self) -> str:
83-
return self._name
84-
85-
@property
86-
def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]:
87-
return "id"
88-
89-
@property
90-
def state(self) -> Mapping[str, Any]:
91-
return self._state
92-
93-
@state.setter
94-
def state(self, value: Mapping[str, Any]) -> None:
95-
self._state = value
96-
97-
@property
98-
def is_resumable(self) -> bool:
99-
return bool(self._cursor)
100-
101-
def get_cursor(self) -> Optional[Cursor]:
102-
return self._cursor
103-
104-
def stream_slices(
105-
self,
106-
*,
107-
sync_mode: SyncMode,
108-
cursor_field: List[str] = None,
109-
stream_state: Mapping[str, Any] = None,
110-
) -> Iterable[Optional[StreamSlice]]:
111-
for s in self._slices:
112-
if isinstance(s, StreamSlice):
113-
yield s
114-
else:
115-
yield StreamSlice(partition=s, cursor_slice={})
116-
117-
def read_records(
118-
self,
119-
sync_mode: SyncMode,
120-
cursor_field: List[str] = None,
121-
stream_slice: Mapping[str, Any] = None,
122-
stream_state: Mapping[str, Any] = None,
123-
) -> Iterable[Mapping[str, Any]]:
124-
# The parent stream's records should always be read as full refresh
125-
assert sync_mode == SyncMode.full_refresh
126-
127-
if not stream_slice:
128-
result = self._records
129-
else:
130-
result = [
131-
Record(data=r, associated_slice=stream_slice, stream_name=self.name)
132-
for r in self._records
133-
if r["slice"] == stream_slice["slice"]
134-
]
135-
136-
yield from result
137-
138-
# Update the state only after reading the full slice
139-
cursor_field = self._stream_cursor_field.eval(config=self.config)
140-
if stream_slice and cursor_field and result:
141-
self._state["states"].append(
142-
{cursor_field: result[-1][cursor_field], "partition": stream_slice["slice"]}
143-
)
144-
145-
def get_json_schema(self) -> Mapping[str, Any]:
146-
return {}
65+
from .helpers import MockStream
14766

14867

14968
class MockIncrementalStream(MockStream):

0 commit comments

Comments
 (0)