Skip to content

Commit f9eb050

Browse files
author
maxime.c
committed
fix remaining tests in test_per_partition_cursor_integration
1 parent 6c8771c commit f9eb050

File tree

2 files changed

+16
-54
lines changed

2 files changed

+16
-54
lines changed

airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class ConcurrentPerPartitionCursor(Cursor):
5555
Manages state per partition when a stream has many partitions, preventing data loss or duplication.
5656
5757
Attributes:
58-
DEFAULT_MAX_PARTITIONS_NUMBER (int): Maximum number of partitions to retain in memory (default is 10,000).
58+
DEFAULT_MAX_PARTITIONS_NUMBER (int): Maximum number of partitions to retain in memory (default is 10,000). This limit needs to be higher than the number of threads we might enqueue (which is represented by ThreadPoolManager.DEFAULT_MAX_QUEUE_SIZE). Otherwise, partitions that have been generated and submitted to the ThreadPool could be deleted from the ConcurrentPerPartitionCursor before they are closed, and closing them would then raise a KeyError.
5959
6060
- **Partition Limitation Logic**
6161
Ensures the number of tracked partitions does not exceed the specified limit to prevent memory overuse. Oldest partitions are removed when the limit is reached.

unit_tests/sources/declarative/incremental/test_per_partition_cursor_integration.py

Lines changed: 15 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -394,13 +394,13 @@ def test_substream_without_input_state():
394394
]
395395

396396

397-
def test_partition_limitation(caplog):
397+
def test_switch_to_global_limit(caplog):
398398
"""
399-
Test that when the number of partitions exceeds the maximum allowed limit in PerPartitionCursor,
400-
the oldest partitions are dropped, and the state is updated accordingly.
399+
Test that the cursor switches to global state when the number of partitions exceeds the limit.
401400
402-
In this test, we set the maximum number of partitions to 2 and provide 3 partitions.
403-
We verify that the state only retains information for the two most recent partitions.
401+
In this test, we set the maximum number of partitions to 1 (not 2 because we evaluate this before generating a
402+
partition and the limit is not inclusive) and provide 3 partitions.
403+
We verify that the state switches to global.
404404
"""
405405
stream_name = "Rates"
406406

@@ -508,15 +508,15 @@ def test_partition_limitation(caplog):
508508
)
509509

510510
# Use caplog to capture logs
511-
with caplog.at_level(logging.WARNING, logger="airbyte"):
511+
with caplog.at_level(logging.INFO, logger="airbyte"):
512512
with patch.object(SimpleRetriever, "_read_pages", side_effect=records_list):
513-
with patch.object(ConcurrentPerPartitionCursor, "DEFAULT_MAX_PARTITIONS_NUMBER", 2):
513+
with patch.object(ConcurrentPerPartitionCursor, "SWITCH_TO_GLOBAL_LIMIT", 1):
514514
output = list(source.read(logger, {}, catalog, initial_state))
515515

516516
# Check if the warning was logged
517-
logged_messages = [record.message for record in caplog.records if record.levelname == "WARNING"]
518-
warning_message = 'The maximum number of partitions has been reached. Dropping the oldest partition: {"partition_field":"1"}. Over limit: 1.'
519-
assert warning_message in logged_messages
517+
logged_messages = [record.message for record in caplog.records if record.levelname == "INFO"]
518+
warning_message = "Exceeded the 'SWITCH_TO_GLOBAL_LIMIT' of"
519+
assert any(map(lambda message: warning_message in message, logged_messages))
520520

521521
final_state = [
522522
orjson.loads(orjson.dumps(message.state.stream.stream_state))
@@ -526,17 +526,7 @@ def test_partition_limitation(caplog):
526526
assert final_state[-1] == {
527527
"lookback_window": 1,
528528
"state": {"cursor_field": "2022-02-17"},
529-
"use_global_cursor": False,
530-
"states": [
531-
{
532-
"partition": {"partition_field": "2"},
533-
"cursor": {CURSOR_FIELD: "2022-01-16"},
534-
},
535-
{
536-
"partition": {"partition_field": "3"},
537-
"cursor": {CURSOR_FIELD: "2022-02-17"},
538-
},
539-
],
529+
"use_global_cursor": True,
540530
}
541531

542532

@@ -684,38 +674,10 @@ def test_perpartition_with_fallback(caplog):
684674
state=initial_state,
685675
)
686676

687-
# Use caplog to capture logs
688-
with caplog.at_level(logging.WARNING, logger="airbyte"):
689-
with patch.object(SimpleRetriever, "_read_pages", side_effect=records_list):
690-
with patch.object(ConcurrentPerPartitionCursor, "DEFAULT_MAX_PARTITIONS_NUMBER", 2):
691-
with patch.object(ConcurrentPerPartitionCursor, "SWITCH_TO_GLOBAL_LIMIT", 1):
692-
output = list(source.read(logger, {}, catalog, initial_state))
693-
694-
# Check if the warnings were logged
695-
logged_messages = [record.message for record in caplog.records if record.levelname == "WARNING"]
696-
warning_message = (
697-
"The maximum number of partitions has been reached. Dropping the oldest partition:"
698-
)
699-
expected_warning_over_limit_messages = [
700-
"Over limit: 1",
701-
"Over limit: 2",
702-
"Over limit: 3",
703-
]
704-
705-
for logged_message in logged_messages:
706-
assert warning_message in logged_message
707-
708-
for expected_warning_over_limit_message in expected_warning_over_limit_messages:
709-
assert (
710-
len(
711-
[
712-
logged_message
713-
for logged_message in logged_messages
714-
if expected_warning_over_limit_message in logged_message
715-
]
716-
)
717-
> 0
718-
)
677+
with patch.object(SimpleRetriever, "_read_pages", side_effect=records_list):
678+
with patch.object(ConcurrentPerPartitionCursor, "DEFAULT_MAX_PARTITIONS_NUMBER", 2):
679+
with patch.object(ConcurrentPerPartitionCursor, "SWITCH_TO_GLOBAL_LIMIT", 1):
680+
output = list(source.read(logger, {}, catalog, initial_state))
719681

720682
# Proceed with existing assertions
721683
final_state = [

0 commit comments

Comments
 (0)