Skip to content

Commit d3e7fe2

Browse files
committed
Add parent state updates
1 parent 667700f commit d3e7fe2

File tree

1 file changed

+41
-8
lines changed

1 file changed

+41
-8
lines changed

airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,10 @@ def __init__(
9595
# the oldest partitions can be efficiently removed, maintaining the most recent partitions.
9696
self._cursor_per_partition: OrderedDict[str, ConcurrentCursor] = OrderedDict()
9797
self._semaphore_per_partition: OrderedDict[str, threading.Semaphore] = OrderedDict()
98+
99+
# Parent-state tracking: store each partition’s parent state in creation order
100+
self._partition_parent_state_map: OrderedDict[str, Mapping[str, Any]] = OrderedDict()
101+
98102
self._finished_partitions: set[str] = set()
99103
self._lock = threading.Lock()
100104
self._timer = Timer()
@@ -154,10 +158,32 @@ def close_partition(self, partition: Partition) -> None:
154158
and self._semaphore_per_partition[partition_key]._value == 0
155159
):
156160
self._update_global_cursor(cursor.state[self.cursor_field.cursor_field_key])
157-
self._emit_state_message()
161+
162+
self._check_and_update_parent_state()
163+
164+
self._emit_state_message()
158165

159166
self._semaphore_per_partition[partition_key].acquire()
160167

168+
def _check_and_update_parent_state(self) -> None:
    """
    Advance ``_parent_state`` past every leading partition that is fully closed.

    ``_partition_parent_state_map`` is walked from its oldest entry (insertion
    order). An entry is removed while its partition key is in
    ``_finished_partitions`` and its semaphore counter is zero. The walk stops
    at the first entry that does not satisfy both conditions, so a later
    partition never advances the parent state ahead of an earlier one that is
    still open.

    If at least one entry was removed, ``_parent_state`` is set to the parent
    state stored with the last removed entry; otherwise it is left untouched.
    """
    last_closed_state = None
    while self._partition_parent_state_map:
        # Peek at the oldest key only; copying the whole key list on every
        # call would cost O(n) even when the first partition is still open.
        earliest_key = next(iter(self._partition_parent_state_map))
        # NOTE(review): `_value` is a private attribute of threading.Semaphore;
        # it is read here the same way close_partition reads it.
        if (
            earliest_key not in self._finished_partitions
            or self._semaphore_per_partition[earliest_key]._value != 0
        ):
            break
        # Remove exactly the entry that was just checked.
        last_closed_state = self._partition_parent_state_map.pop(earliest_key)

    # Only overwrite the parent state when something was actually removed.
    if last_closed_state is not None:
        self._parent_state = last_closed_state
186+
161187
def ensure_at_least_one_state_emitted(self) -> None:
162188
"""
163189
The platform expect to have at least one state message on successful syncs. Hence, whatever happens, we expect this method to be
@@ -202,32 +228,39 @@ def stream_slices(self) -> Iterable[StreamSlice]:
202228

203229
slices = self._partition_router.stream_slices()
204230
self._timer.start()
205-
for partition in slices:
206-
yield from self._generate_slices_from_partition(partition)
231+
for partition, last, parent_state in iterate_with_last_flag_and_state(
232+
slices, self._partition_router.get_stream_state
233+
):
234+
yield from self._generate_slices_from_partition(partition, parent_state)
207235

208-
def _generate_slices_from_partition(self, partition: StreamSlice) -> Iterable[StreamSlice]:
236+
def _generate_slices_from_partition(
    self, partition: StreamSlice, parent_state: Mapping[str, Any]
) -> Iterable[StreamSlice]:
    """
    Yield one StreamSlice per cursor slice of the given partition.

    A cursor and a semaphore are registered for the partition the first time
    it is seen, and a deep copy of ``parent_state`` is recorded for it in
    ``_partition_parent_state_map``. The semaphore is released once for every
    slice generated, and the partition key is added to
    ``_finished_partitions`` when its last slice is generated.

    :param partition: the partition produced by the partition router
    :param parent_state: the parent state to associate with this partition
    :return: the partition combined with each of its cursor slices
    """
    # Ensure the maximum number of partitions is not exceeded
    self._ensure_partition_limit()

    # Compute the key once and reuse it for every per-partition lookup below.
    partition_key = self._to_partition_key(partition.partition)

    cursor = self._cursor_per_partition.get(partition_key)
    if not cursor:
        cursor = self._create_cursor(
            self._global_cursor,
            self._lookback_window if self._global_cursor else 0,
        )
        with self._lock:
            self._cursor_per_partition[partition_key] = cursor
            self._semaphore_per_partition[partition_key] = threading.Semaphore(0)

    with self._lock:
        # Store a copy so the recorded state is independent of the object passed in.
        self._partition_parent_state_map[partition_key] = deepcopy(parent_state)

    for cursor_slice, is_last_slice, _ in iterate_with_last_flag_and_state(
        cursor.stream_slices(),
        lambda: None,
    ):
        # One release per generated slice.
        self._semaphore_per_partition[partition_key].release()
        if is_last_slice:
            self._finished_partitions.add(partition_key)
        yield StreamSlice(
            partition=partition, cursor_slice=cursor_slice, extra_fields=partition.extra_fields
        )

0 commit comments

Comments
 (0)