|
9 | 9 | from collections import OrderedDict |
10 | 10 | from copy import deepcopy |
11 | 11 | from datetime import timedelta |
12 | | -from typing import Any, Callable, Iterable, Mapping, MutableMapping, Optional |
| 12 | +from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional |
13 | 13 |
|
14 | 14 | from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager |
15 | 15 | from airbyte_cdk.sources.declarative.incremental.global_substream_cursor import ( |
@@ -66,8 +66,8 @@ class ConcurrentPerPartitionCursor(Cursor): |
66 | 66 | _GLOBAL_STATE_KEY = "state" |
67 | 67 | _PERPARTITION_STATE_KEY = "states" |
68 | 68 | _IS_PARTITION_DUPLICATION_LOGGED = False |
69 | | - _KEY = 0 |
70 | | - _VALUE = 1 |
| 69 | + _PARENT_STATE = 0 |
| 70 | + _GENERATION_SEQUENCE = 1 |
71 | 71 |
|
72 | 72 | def __init__( |
73 | 73 | self, |
@@ -99,19 +99,29 @@ def __init__( |
99 | 99 | self._semaphore_per_partition: OrderedDict[str, threading.Semaphore] = OrderedDict() |
100 | 100 |
|
101 | 101 | # Parent-state tracking: store each partition’s parent state in creation order |
102 | | - self._partition_parent_state_map: OrderedDict[str, Mapping[str, Any]] = OrderedDict() |
| 102 | + self._partition_parent_state_map: OrderedDict[str, tuple[Mapping[str, Any], int]] = ( |
| 103 | + OrderedDict() |
| 104 | + ) |
| 105 | + self._parent_state: Optional[StreamState] = None |
| 106 | + |
| 107 | +    # Tracks partitions whose last stream slice has been emitted
| 108 | + self._partitions_done_generating_stream_slices: set[str] = set() |
| 109 | +    # Used to track the indexes of partitions that are not closed yet
| 110 | + self._processing_partitions_indexes: List[int] = list() |
| 111 | + self._generated_partitions_count: int = 0 |
| 112 | + # Dictionary to map partition keys to their index |
| 113 | + self._partition_key_to_index: dict[str, int] = {} |
103 | 114 |
|
104 | | - self._finished_partitions: set[str] = set() |
105 | 115 | self._lock = threading.Lock() |
106 | | - self._timer = Timer() |
107 | | - self._new_global_cursor: Optional[StreamState] = None |
108 | 116 | self._lookback_window: int = 0 |
109 | | - self._parent_state: Optional[StreamState] = None |
| 117 | + self._new_global_cursor: Optional[StreamState] = None |
110 | 118 | self._number_of_partitions: int = 0 |
111 | 119 | self._use_global_cursor: bool = use_global_cursor |
112 | 120 | self._partition_serializer = PerPartitionKeySerializer() |
| 121 | + |
113 | 122 | # Track the last time a state message was emitted |
114 | 123 | self._last_emission_time: float = 0.0 |
| 124 | + self._timer = Timer() |
115 | 125 |
|
116 | 126 | self._set_initial_state(stream_state) |
117 | 127 |
|
@@ -157,60 +167,37 @@ def close_partition(self, partition: Partition) -> None: |
157 | 167 | self._cursor_per_partition[partition_key].close_partition(partition=partition) |
158 | 168 | cursor = self._cursor_per_partition[partition_key] |
159 | 169 | if ( |
160 | | - partition_key in self._finished_partitions |
| 170 | + partition_key in self._partitions_done_generating_stream_slices |
161 | 171 | and self._semaphore_per_partition[partition_key]._value == 0 |
162 | 172 | ): |
163 | 173 | self._update_global_cursor(cursor.state[self.cursor_field.cursor_field_key]) |
164 | 174 |
|
| 175 | + # Clean up the partition if it is fully processed |
| 176 | + self._cleanup_if_done(partition_key) |
| 177 | + |
165 | 178 | self._check_and_update_parent_state() |
166 | 179 |
|
167 | 180 | self._emit_state_message() |
168 | 181 |
|
169 | 182 | def _check_and_update_parent_state(self) -> None: |
170 | | - """ |
171 | | - Pop the leftmost partition state from _partition_parent_state_map only if |
172 | | - *all partitions* up to (and including) that partition key in _semaphore_per_partition |
173 | | - are fully finished (i.e. in _finished_partitions and semaphore._value == 0). |
174 | | - Additionally, delete finished semaphores with a value of 0 to free up memory, |
175 | | - as they are only needed to track errors and completion status. |
176 | | - """ |
177 | 183 | last_closed_state = None |
178 | 184 |
|
179 | 185 | while self._partition_parent_state_map: |
180 | | - # Look at the earliest partition key in creation order |
181 | | - earliest_key = next(iter(self._partition_parent_state_map)) |
182 | | - |
183 | | - # Verify ALL partitions from the left up to earliest_key are finished |
184 | | - all_left_finished = True |
185 | | - for p_key, sem in list( |
186 | | - self._semaphore_per_partition.items() |
187 | | - ): # Use list to allow modification during iteration |
188 | | - # If any earlier partition is still not finished, we must stop |
189 | | - if p_key not in self._finished_partitions or sem._value != 0: |
190 | | - all_left_finished = False |
191 | | - break |
192 | | - # Once we've reached earliest_key in the semaphore order, we can stop checking |
193 | | - if p_key == earliest_key: |
194 | | - break |
195 | | - |
196 | | - # If the partitions up to earliest_key are not all finished, break the while-loop |
197 | | - if not all_left_finished: |
198 | | - break |
| 186 | + earliest_key, (candidate_state, candidate_seq) = next( |
| 187 | + iter(self._partition_parent_state_map.items()) |
| 188 | + ) |
199 | 189 |
|
200 | | - # Pop the leftmost entry from parent-state map |
201 | | - _, closed_parent_state = self._partition_parent_state_map.popitem(last=False) |
202 | | - last_closed_state = closed_parent_state |
| 190 | + # if any partition that started <= candidate_seq is still open, we must wait |
| 191 | + if ( |
| 192 | + self._processing_partitions_indexes |
| 193 | + and self._processing_partitions_indexes[0] <= candidate_seq |
| 194 | + ): |
| 195 | + break |
203 | 196 |
|
204 | | - # Clean up finished semaphores with value 0 up to and including earliest_key |
205 | | - for p_key in list(self._semaphore_per_partition.keys()): |
206 | | - sem = self._semaphore_per_partition[p_key] |
207 | | - if p_key in self._finished_partitions and sem._value == 0: |
208 | | - del self._semaphore_per_partition[p_key] |
209 | | - logger.debug(f"Deleted finished semaphore for partition {p_key} with value 0") |
210 | | - if p_key == earliest_key: |
211 | | - break |
| 197 | + # safe to pop |
| 198 | + self._partition_parent_state_map.popitem(last=False) |
| 199 | + last_closed_state = candidate_state |
212 | 200 |
|
213 | | - # Update _parent_state if we popped at least one partition |
214 | 201 | if last_closed_state is not None: |
215 | 202 | self._parent_state = last_closed_state |
216 | 203 |
|
@@ -289,26 +276,32 @@ def _generate_slices_from_partition( |
289 | 276 | if not self._IS_PARTITION_DUPLICATION_LOGGED: |
290 | 277 | logger.warning(f"Partition duplication detected for stream {self._stream_name}") |
291 | 278 | self._IS_PARTITION_DUPLICATION_LOGGED = True |
| 279 | + return |
292 | 280 | else: |
293 | 281 | self._semaphore_per_partition[partition_key] = threading.Semaphore(0) |
294 | 282 |
|
295 | 283 | with self._lock: |
| 284 | + seq = self._generated_partitions_count |
| 285 | + self._generated_partitions_count += 1 |
| 286 | + self._processing_partitions_indexes.append(seq) |
| 287 | + self._partition_key_to_index[partition_key] = seq |
| 288 | + |
296 | 289 | if ( |
297 | 290 | len(self._partition_parent_state_map) == 0 |
298 | 291 | or self._partition_parent_state_map[ |
299 | 292 | next(reversed(self._partition_parent_state_map)) |
300 | | - ] |
| 293 | + ][self._PARENT_STATE] |
301 | 294 | != parent_state |
302 | 295 | ): |
303 | | - self._partition_parent_state_map[partition_key] = deepcopy(parent_state) |
| 296 | + self._partition_parent_state_map[partition_key] = (deepcopy(parent_state), seq) |
304 | 297 |
|
305 | 298 | for cursor_slice, is_last_slice, _ in iterate_with_last_flag_and_state( |
306 | 299 | cursor.stream_slices(), |
307 | 300 | lambda: None, |
308 | 301 | ): |
309 | 302 | self._semaphore_per_partition[partition_key].release() |
310 | 303 | if is_last_slice: |
311 | | - self._finished_partitions.add(partition_key) |
| 304 | + self._partitions_done_generating_stream_slices.add(partition_key) |
312 | 305 | yield StreamSlice( |
313 | 306 | partition=partition, cursor_slice=cursor_slice, extra_fields=partition.extra_fields |
314 | 307 | ) |
@@ -338,14 +331,11 @@ def _ensure_partition_limit(self) -> None: |
338 | 331 | while len(self._cursor_per_partition) > self.DEFAULT_MAX_PARTITIONS_NUMBER - 1: |
339 | 332 | # Try removing finished partitions first |
340 | 333 | for partition_key in list(self._cursor_per_partition.keys()): |
341 | | - if partition_key in self._finished_partitions and ( |
342 | | - partition_key not in self._semaphore_per_partition |
343 | | - or self._semaphore_per_partition[partition_key]._value == 0 |
344 | | - ): |
| 334 | + if partition_key not in self._partition_key_to_index: |
345 | 335 | oldest_partition = self._cursor_per_partition.pop( |
346 | 336 | partition_key |
347 | 337 | ) # Remove the oldest partition |
348 | | - logger.warning( |
| 338 | + logger.debug( |
349 | 339 | f"The maximum number of partitions has been reached. Dropping the oldest finished partition: {oldest_partition}. Over limit: {self._number_of_partitions - self.DEFAULT_MAX_PARTITIONS_NUMBER}." |
350 | 340 | ) |
351 | 341 | break |
@@ -474,6 +464,25 @@ def _update_global_cursor(self, value: Any) -> None: |
474 | 464 | ): |
475 | 465 | self._new_global_cursor = {self.cursor_field.cursor_field_key: copy.deepcopy(value)} |
476 | 466 |
|
| 467 | + def _cleanup_if_done(self, partition_key: str) -> None: |
| 468 | + """ |
| 469 | + Free every in-memory structure that belonged to a completed partition: |
| 470 | +        cursor, semaphore, flag inside `_partitions_done_generating_stream_slices`
| 471 | + """ |
| 472 | + if not ( |
| 473 | + partition_key in self._partitions_done_generating_stream_slices |
| 474 | + and self._semaphore_per_partition[partition_key]._value == 0 |
| 475 | + ): |
| 476 | + return |
| 477 | + |
| 478 | + self._semaphore_per_partition.pop(partition_key, None) |
| 479 | + self._partitions_done_generating_stream_slices.discard(partition_key) |
| 480 | + |
| 481 | + seq = self._partition_key_to_index.pop(partition_key) |
| 482 | + self._processing_partitions_indexes.remove(seq) |
| 483 | + |
| 484 | + logger.debug(f"Partition {partition_key} fully processed and cleaned up.") |
| 485 | + |
477 | 486 | def _to_partition_key(self, partition: Mapping[str, Any]) -> str: |
478 | 487 | return self._partition_serializer.to_partition_key(partition) |
479 | 488 |
|
|
0 commit comments