|
11 | 11 | from datetime import timedelta |
12 | 12 | from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, TypeVar |
13 | 13 |
|
| 14 | +from airbyte_cdk.models import ( |
| 15 | + AirbyteStateBlob, |
| 16 | + AirbyteStateMessage, |
| 17 | + AirbyteStateType, |
| 18 | + AirbyteStreamState, |
| 19 | + StreamDescriptor, |
| 20 | +) |
14 | 21 | from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager |
15 | 22 | from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter |
16 | 23 | from airbyte_cdk.sources.message import MessageRepository |
@@ -101,7 +108,7 @@ class ConcurrentPerPartitionCursor(Cursor): |
101 | 108 | Manages state per partition when a stream has many partitions, preventing data loss or duplication. |
102 | 109 |
|
103 | 110 | Attributes: |
104 | | - DEFAULT_MAX_PARTITIONS_NUMBER (int): Maximum number of partitions to retain in memory (default is 10,000). |
| 111 | + DEFAULT_MAX_PARTITIONS_NUMBER (int): Maximum number of partitions to retain in memory (default is 10,000). This limit needs to be higher than the number of threads we might enqueue (which is represented by ThreadPoolManager.DEFAULT_MAX_QUEUE_SIZE). Otherwise, a partition that has been generated and submitted to the ThreadPool could be deleted from the ConcurrentPerPartitionCursor before it is closed, and closing it would then raise a KeyError. |
105 | 112 |
|
106 | 113 | - **Partition Limitation Logic** |
107 | 114 | Ensures the number of tracked partitions does not exceed the specified limit to prevent memory overuse. Oldest partitions are removed when the limit is reached. |
@@ -181,6 +188,7 @@ def __init__( |
181 | 188 |
|
182 | 189 | # FIXME this is a temporary field for the duration of the migration from declarative cursors to concurrent ones |
183 | 190 | self._attempt_to_create_cursor_if_not_provided = attempt_to_create_cursor_if_not_provided |
| 191 | + self._synced_some_data = False |
184 | 192 |
|
185 | 193 | @property |
186 | 194 | def cursor_field(self) -> CursorField: |
@@ -221,8 +229,8 @@ def close_partition(self, partition: Partition) -> None: |
221 | 229 | with self._lock: |
222 | 230 | self._semaphore_per_partition[partition_key].acquire() |
223 | 231 | if not self._use_global_cursor: |
224 | | - self._cursor_per_partition[partition_key].close_partition(partition=partition) |
225 | 232 | cursor = self._cursor_per_partition[partition_key] |
| 233 | + cursor.close_partition(partition=partition) |
226 | 234 | if ( |
227 | 235 | partition_key in self._partitions_done_generating_stream_slices |
228 | 236 | and self._semaphore_per_partition[partition_key]._value == 0 |
@@ -266,8 +274,10 @@ def ensure_at_least_one_state_emitted(self) -> None: |
266 | 274 | if not any( |
267 | 275 | semaphore_item[1]._value for semaphore_item in self._semaphore_per_partition.items() |
268 | 276 | ): |
269 | | - self._global_cursor = self._new_global_cursor |
270 | | - self._lookback_window = self._timer.finish() |
| 277 | + if self._synced_some_data: |
| 278 | + # only update the global cursor and lookback window if we actually synced some data |
| 279 | + self._global_cursor = self._new_global_cursor |
| 280 | + self._lookback_window = self._timer.finish() |
271 | 281 | self._parent_state = self._partition_router.get_stream_state() |
272 | 282 | self._emit_state_message(throttle=False) |
273 | 283 |
|
@@ -475,9 +485,6 @@ def _set_initial_state(self, stream_state: StreamState) -> None: |
475 | 485 | if stream_state.get("parent_state"): |
476 | 486 | self._parent_state = stream_state["parent_state"] |
477 | 487 |
|
478 | | - # Set parent state for partition routers based on parent streams |
479 | | - self._partition_router.set_initial_state(stream_state) |
480 | | - |
481 | 488 | def _set_global_state(self, stream_state: Mapping[str, Any]) -> None: |
482 | 489 | """ |
483 | 490 | Initializes the global cursor state from the provided stream state. |
@@ -511,6 +518,7 @@ def observe(self, record: Record) -> None: |
511 | 518 | except ValueError: |
512 | 519 | return |
513 | 520 |
|
| 521 | + self._synced_some_data = True |
514 | 522 | record_cursor = self._connector_state_converter.output_format( |
515 | 523 | self._connector_state_converter.parse_value(record_cursor_value) |
516 | 524 | ) |
@@ -594,3 +602,45 @@ def _get_cursor(self, record: Record) -> ConcurrentCursor: |
594 | 602 |
|
595 | 603 | def limit_reached(self) -> bool: |
596 | 604 | return self._number_of_partitions > self.SWITCH_TO_GLOBAL_LIMIT |
| 605 | + |
| 606 | + @staticmethod |
| 607 | + def get_parent_state( |
| 608 | + stream_state: Optional[StreamState], parent_stream_name: str |
| 609 | + ) -> Optional[AirbyteStateMessage]: |
| 610 | + if not stream_state: |
| 611 | + return None |
| 612 | + |
| 613 | + if "parent_state" not in stream_state: |
| 614 | + logger.warning( |
| 615 | + f"Trying to get_parent_state for stream `{parent_stream_name}` when there are not parent state in the state" |
| 616 | + ) |
| 617 | + return None |
| 618 | + elif parent_stream_name not in stream_state["parent_state"]: |
| 619 | + logger.info( |
| 620 | + f"Could not find parent state for stream `{parent_stream_name}`. On parents available are {list(stream_state['parent_state'].keys())}" |
| 621 | + ) |
| 622 | + return None |
| 623 | + |
| 624 | + return AirbyteStateMessage( |
| 625 | + type=AirbyteStateType.STREAM, |
| 626 | + stream=AirbyteStreamState( |
| 627 | + stream_descriptor=StreamDescriptor(parent_stream_name, None), |
| 628 | + stream_state=AirbyteStateBlob(stream_state["parent_state"][parent_stream_name]), |
| 629 | + ), |
| 630 | + ) |
| 631 | + |
| 632 | + @staticmethod |
| 633 | + def get_global_state( |
| 634 | + stream_state: Optional[StreamState], parent_stream_name: str |
| 635 | + ) -> Optional[AirbyteStateMessage]: |
| 636 | + return ( |
| 637 | + AirbyteStateMessage( |
| 638 | + type=AirbyteStateType.STREAM, |
| 639 | + stream=AirbyteStreamState( |
| 640 | + stream_descriptor=StreamDescriptor(parent_stream_name, None), |
| 641 | + stream_state=AirbyteStateBlob(stream_state["state"]), |
| 642 | + ), |
| 643 | + ) |
| 644 | + if stream_state and "state" in stream_state |
| 645 | + else None |
| 646 | + ) |
0 commit comments