airbyte_cdk/manifest_server/command_processor/processor.py (1 change: 0 additions & 1 deletion)

@@ -41,7 +41,6 @@ def test_read(
         """
         Test the read method of the source.
         """
-
         test_read_handler = TestReader(
             max_pages_per_slice=page_limit,
             max_slices=slice_limit,
airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py

@@ -14,10 +14,21 @@
 from airbyte_cdk.sources.types import Record, StreamSlice
 from airbyte_cdk.utils.slice_hasher import SliceHasher


 # For Connector Builder test read operations, we track the total number of records
-# read for the stream at the global level so that we can stop reading early if we
-# exceed the record limit
-total_record_counter = 0
+# read for the stream so that we can stop reading early if we exceed the record limit.
+class RecordCounter:
+    def __init__(self) -> None:
+        self.total_record_counter = 0
+
+    def increment(self) -> None:
+        self.total_record_counter += 1
+
+    def reset(self) -> None:
+        self.total_record_counter = 0
+
+    def get_total_records(self) -> int:
+        return self.total_record_counter


 class SchemaLoaderCachingDecorator(SchemaLoader):

@@ -51,6 +62,7 @@ def __init__(
         self._retriever = retriever
         self._message_repository = message_repository
         self._max_records_limit = max_records_limit
+        self._record_counter = RecordCounter()

     def create(self, stream_slice: StreamSlice) -> Partition:
         return DeclarativePartition(

@@ -60,6 +72,7 @@ def create(self, stream_slice: StreamSlice) -> Partition:
             message_repository=self._message_repository,
             max_records_limit=self._max_records_limit,
             stream_slice=stream_slice,
+            record_counter=self._record_counter,
         )

@@ -72,6 +85,7 @@ def __init__(
         message_repository: MessageRepository,
         max_records_limit: Optional[int],
         stream_slice: StreamSlice,
+        record_counter: RecordCounter,
     ):
         self._stream_name = stream_name
         self._schema_loader = schema_loader

@@ -80,17 +94,17 @@ def __init__(
         self._max_records_limit = max_records_limit
         self._stream_slice = stream_slice
         self._hash = SliceHasher.hash(self._stream_name, self._stream_slice)
+        self._record_counter = record_counter

     def read(self) -> Iterable[Record]:
         if self._max_records_limit is not None:
-            global total_record_counter
-            if total_record_counter >= self._max_records_limit:
+            if self._record_counter.get_total_records() >= self._max_records_limit:
                 return
         for stream_data in self._retriever.read_records(
             self._schema_loader.get_json_schema(), self._stream_slice
         ):
             if self._max_records_limit is not None:
-                if total_record_counter >= self._max_records_limit:
+                if self._record_counter.get_total_records() >= self._max_records_limit:
                     break

             if isinstance(stream_data, Mapping):

@@ -108,7 +122,7 @@ def read(self) -> Iterable[Record]:
                 self._message_repository.emit_message(stream_data)

             if self._max_records_limit is not None:
-                total_record_counter += 1
+                self._record_counter.increment()

     def to_slice(self) -> Optional[Mapping[str, Any]]:
         return self._stream_slice
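The net effect of the diff above is that the record counter moves from module scope to instance scope: each DeclarativePartitionFactory now owns one RecordCounter and hands it to every DeclarativePartition it creates, so a max_records_limit still applies across all slices of a single stream read but no longer leaks between reads. Below is a minimal, runnable sketch of that ownership model; SketchFactory and SketchPartition are simplified hypothetical stand-ins for the CDK classes, and records are reduced to plain ints:

from typing import Iterable, List


class RecordCounter:
    def __init__(self) -> None:
        self.total_record_counter = 0

    def increment(self) -> None:
        self.total_record_counter += 1

    def get_total_records(self) -> int:
        return self.total_record_counter


class SketchPartition:
    """Hypothetical stand-in for DeclarativePartition (simplified)."""

    def __init__(self, records: List[int], limit: int, counter: RecordCounter) -> None:
        self._records = records
        self._limit = limit
        self._counter = counter

    def read(self) -> Iterable[int]:
        # Stop early once the shared counter reaches the factory-wide limit.
        for record in self._records:
            if self._counter.get_total_records() >= self._limit:
                return
            self._counter.increment()
            yield record


class SketchFactory:
    """Hypothetical stand-in for DeclarativePartitionFactory (simplified)."""

    def __init__(self, limit: int) -> None:
        self._limit = limit
        self._counter = RecordCounter()  # instance-scoped, not module-global

    def create(self, records: List[int]) -> SketchPartition:
        return SketchPartition(records, self._limit, self._counter)


factory_a = SketchFactory(limit=2)
factory_b = SketchFactory(limit=2)
assert len(list(factory_a.create([1, 2, 3]).read())) == 2
# With the old module-level counter, the limit would already be exhausted
# here and the second factory would read zero records.
assert len(list(factory_b.create([1, 2, 3]).read())) == 2

Keeping the counter on the factory rather than on each partition is the point of the design: partitions are created per slice, while the factory lives for the duration of one stream read, which is exactly the scope the record limit is meant to have. The new isolation test in the diff below verifies the same property against the real classes.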
unit_tests/sources/declarative/stream_slicers/test_declarative_partition_generator.py
@@ -4,8 +4,6 @@
 from unittest import TestCase
 from unittest.mock import Mock

-# This allows for the global total_record_counter to be reset between tests
-import airbyte_cdk.sources.declarative.stream_slicers.declarative_partition_generator as declarative_partition_generator
 from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level, Type
 from airbyte_cdk.sources.declarative.retrievers import Retriever
 from airbyte_cdk.sources.declarative.schema import InlineSchemaLoader

@@ -35,7 +33,7 @@ class StreamSlicerPartitionGeneratorTest(TestCase):
     def test_given_multiple_slices_partition_generator_uses_the_same_retriever(self) -> None:
         retriever = self._mock_retriever([])
         message_repository = Mock(spec=MessageRepository)
-        partition_factory = declarative_partition_generator.DeclarativePartitionFactory(
+        partition_factory = DeclarativePartitionFactory(
             _STREAM_NAME,
             _SCHEMA_LOADER,
             retriever,

@@ -50,7 +48,7 @@ def test_given_multiple_slices_partition_generator_uses_the_same_retriever(self) -> None:
     def test_given_a_mapping_when_read_then_yield_record(self) -> None:
         retriever = self._mock_retriever([_A_RECORD])
         message_repository = Mock(spec=MessageRepository)
-        partition_factory = declarative_partition_generator.DeclarativePartitionFactory(
+        partition_factory = DeclarativePartitionFactory(
             _STREAM_NAME,
             _SCHEMA_LOADER,
             retriever,

@@ -68,7 +66,7 @@ def test_given_a_mapping_when_read_then_yield_record(self) -> None:
     def test_given_not_a_record_when_read_then_send_to_message_repository(self) -> None:
         retriever = self._mock_retriever([_AIRBYTE_LOG_MESSAGE])
         message_repository = Mock(spec=MessageRepository)
-        partition_factory = declarative_partition_generator.DeclarativePartitionFactory(
+        partition_factory = DeclarativePartitionFactory(
             _STREAM_NAME,
             _SCHEMA_LOADER,
             retriever,

@@ -80,8 +78,6 @@ def test_given_not_a_record_when_read_then_send_to_message_repository(self) -> None:
         message_repository.emit_message.assert_called_once_with(_AIRBYTE_LOG_MESSAGE)

     def test_max_records_reached_stops_reading(self) -> None:
-        declarative_partition_generator.total_record_counter = 0
-
         expected_records = [
             Record(data={"id": 1, "name": "Max"}, stream_name="stream_name"),
             Record(data={"id": 1, "name": "Oscar"}, stream_name="stream_name"),

@@ -97,7 +93,7 @@ def test_max_records_reached_stops_reading(self) -> None:

         retriever = self._mock_retriever(mock_records)
         message_repository = Mock(spec=MessageRepository)
-        partition_factory = declarative_partition_generator.DeclarativePartitionFactory(
+        partition_factory = DeclarativePartitionFactory(
             _STREAM_NAME,
             _SCHEMA_LOADER,
             retriever,

@@ -113,8 +109,6 @@ def test_max_records_reached_stops_reading(self) -> None:
         assert actual_records == expected_records

     def test_max_records_reached_on_previous_partition(self) -> None:
-        declarative_partition_generator.total_record_counter = 0
-
         expected_records = [
             Record(data={"id": 1, "name": "Max"}, stream_name="stream_name"),
             Record(data={"id": 1, "name": "Oscar"}, stream_name="stream_name"),

@@ -128,7 +122,7 @@ def test_max_records_reached_on_previous_partition(self) -> None:

         retriever = self._mock_retriever(mock_records)
         message_repository = Mock(spec=MessageRepository)
-        partition_factory = declarative_partition_generator.DeclarativePartitionFactory(
+        partition_factory = DeclarativePartitionFactory(
             _STREAM_NAME,
             _SCHEMA_LOADER,
             retriever,

@@ -151,6 +145,55 @@ def test_max_records_reached_on_previous_partition(self) -> None:
         # called for the first partition read and not the second
         retriever.read_records.assert_called_once()

+    def test_record_counter_isolation_between_different_factories(self) -> None:
+        """Test that record counters are isolated between different DeclarativePartitionFactory instances."""
+
+        # Create mock records that exceed the limit
+        records = [
+            Record(data={"id": 1, "name": "Record1"}, stream_name="stream_name"),
+            Record(data={"id": 2, "name": "Record2"}, stream_name="stream_name"),
+            Record(
+                data={"id": 3, "name": "Record3"}, stream_name="stream_name"
+            ),  # Should be blocked by the limit
+        ]
+
+        # Create the first factory with a record limit of 2
+        retriever1 = self._mock_retriever(records)
+        message_repository1 = Mock(spec=MessageRepository)
+        factory1 = DeclarativePartitionFactory(
+            _STREAM_NAME,
+            _SCHEMA_LOADER,
+            retriever1,
+            message_repository1,
+            max_records_limit=2,
+        )
+
+        # The first factory should read up to the limit (2 records)
+        partition1 = factory1.create(_A_STREAM_SLICE)
+        first_factory_records = list(partition1.read())
+        assert len(first_factory_records) == 2
+
+        # Create a second factory with the same limit; it should be independent
+        retriever2 = self._mock_retriever(records)
+        message_repository2 = Mock(spec=MessageRepository)
+        factory2 = DeclarativePartitionFactory(
+            _STREAM_NAME,
+            _SCHEMA_LOADER,
+            retriever2,
+            message_repository2,
+            max_records_limit=2,
+        )
+
+        # The second factory should also be able to read up to the limit (2 records).
+        # This would fail before the fix because the record counter was global.
+        partition2 = factory2.create(_A_STREAM_SLICE)
+        second_factory_records = list(partition2.read())
+        assert len(second_factory_records) == 2
+
+        # Verify both retrievers were called (confirming isolation)
+        retriever1.read_records.assert_called_once()
+        retriever2.read_records.assert_called_once()
+
     @staticmethod
     def _mock_retriever(read_return_value: List[StreamData]) -> Mock:
         retriever = Mock(spec=Retriever)