
Commit 4ffa3d8

do an actual fix

1 parent 9eb9ec6 commit 4ffa3d8
4 files changed, +26 -80 lines changed

airbyte_cdk/manifest_server/command_processor/processor.py

Lines changed: 0 additions & 8 deletions

@@ -7,7 +7,6 @@
 )
 from fastapi import HTTPException

-import airbyte_cdk.sources.declarative.stream_slicers.declarative_partition_generator
 from airbyte_cdk.connector_builder.models import StreamRead
 from airbyte_cdk.connector_builder.test_reader import TestReader
 from airbyte_cdk.entrypoint import AirbyteEntrypoint
@@ -42,13 +41,6 @@ def test_read(
         """
         Test the read method of the source.
         """
-
-        # HACK: reset total_record_counter
-        # DeclarativePartition defines total_record_counter as a global variable, which keeps around the record count
-        # across multiple test_read calls, even if the source is different. This is a hack to reset the counter for
-        # each test_read call.
-        airbyte_cdk.sources.declarative.stream_slicers.declarative_partition_generator.total_record_counter = 0
-
         test_read_handler = TestReader(
             max_pages_per_slice=page_limit,
             max_slices=slice_limit,
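
The deleted block worked around cross-call leakage by resetting the module-level counter by hand before every test_read. A minimal, self-contained sketch of that failure mode (illustrative names, not the CDK code):

# Illustrative module-level counter, mirroring the removed global.
total_record_counter = 0


def read_with_global_limit(records, limit):
    """Read until the *shared* counter hits the limit."""
    global total_record_counter
    out = []
    for record in records:
        if total_record_counter >= limit:
            break
        out.append(record)
        total_record_counter += 1
    return out


print(read_with_global_limit([1, 2, 3], limit=2))  # [1, 2]
# A second, unrelated read still sees the stale count and returns nothing:
print(read_with_global_limit([4, 5, 6], limit=2))  # []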

airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py

Lines changed: 21 additions & 7 deletions

@@ -14,10 +14,21 @@
 from airbyte_cdk.sources.types import Record, StreamSlice
 from airbyte_cdk.utils.slice_hasher import SliceHasher

+
 # For Connector Builder test read operations, we track the total number of records
-# read for the stream at the global level so that we can stop reading early if we
-# exceed the record limit
-total_record_counter = 0
+# read for the stream so that we can stop reading early if we exceed the record limit.
+class RecordCounter:
+    def __init__(self):
+        self.total_record_counter = 0
+
+    def increment(self):
+        self.total_record_counter += 1
+
+    def get_total_records(self) -> int:
+        return self.total_record_counter
+
+    def reset(self):
+        self.total_record_counter = 0


 class SchemaLoaderCachingDecorator(SchemaLoader):
@@ -51,6 +62,7 @@ def __init__(
         self._retriever = retriever
         self._message_repository = message_repository
         self._max_records_limit = max_records_limit
+        self._record_counter = RecordCounter()

     def create(self, stream_slice: StreamSlice) -> Partition:
         return DeclarativePartition(
@@ -60,6 +72,7 @@ def create(self, stream_slice: StreamSlice) -> Partition:
             message_repository=self._message_repository,
             max_records_limit=self._max_records_limit,
             stream_slice=stream_slice,
+            record_counter=self._record_counter,
         )


@@ -72,6 +85,7 @@ def __init__(
         message_repository: MessageRepository,
         max_records_limit: Optional[int],
         stream_slice: StreamSlice,
+        record_counter: RecordCounter,
     ):
         self._stream_name = stream_name
         self._schema_loader = schema_loader
@@ -80,17 +94,17 @@ def __init__(
         self._max_records_limit = max_records_limit
         self._stream_slice = stream_slice
         self._hash = SliceHasher.hash(self._stream_name, self._stream_slice)
+        self._record_counter = record_counter

     def read(self) -> Iterable[Record]:
         if self._max_records_limit is not None:
-            global total_record_counter
-            if total_record_counter >= self._max_records_limit:
+            if self._record_counter.get_total_records() >= self._max_records_limit:
                 return
         for stream_data in self._retriever.read_records(
             self._schema_loader.get_json_schema(), self._stream_slice
         ):
             if self._max_records_limit is not None:
-                if total_record_counter >= self._max_records_limit:
+                if self._record_counter.get_total_records() >= self._max_records_limit:
                     break

             if isinstance(stream_data, Mapping):
@@ -108,7 +122,7 @@ def read(self) -> Iterable[Record]:
                 self._message_repository.emit_message(stream_data)

             if self._max_records_limit is not None:
-                total_record_counter += 1
+                self._record_counter.increment()

     def to_slice(self) -> Optional[Mapping[str, Any]]:
         return self._stream_slice
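
The fix scopes the count to the factory: each DeclarativePartitionFactory owns one RecordCounter, and every partition it creates shares that instance. A simplified, runnable sketch of the pattern (stand-in classes, not the CDK implementation):

class RecordCounter:
    def __init__(self):
        self.total_record_counter = 0

    def increment(self):
        self.total_record_counter += 1

    def get_total_records(self) -> int:
        return self.total_record_counter


class PartitionFactory:  # stand-in for DeclarativePartitionFactory
    def __init__(self, max_records_limit):
        self._max_records_limit = max_records_limit
        self._record_counter = RecordCounter()  # one counter per factory

    def create(self, records):
        return Partition(records, self._max_records_limit, self._record_counter)


class Partition:  # stand-in for DeclarativePartition
    def __init__(self, records, max_records_limit, record_counter):
        self._records = records
        self._max_records_limit = max_records_limit
        self._record_counter = record_counter  # shared with sibling partitions

    def read(self):
        for record in self._records:
            if self._record_counter.get_total_records() >= self._max_records_limit:
                break  # limit reached, possibly by an earlier partition
            self._record_counter.increment()
            yield record


factory = PartitionFactory(max_records_limit=3)
first = list(factory.create([1, 2]).read())            # [1, 2]
second = list(factory.create([3, 4, 5]).read())        # [3] -- budget shared within the factory
fresh = list(PartitionFactory(3).create([6, 7]).read())  # [6, 7] -- new factory, fresh count
print(first, second, fresh)

Two reads from the same factory share one record budget, while a fresh factory starts from zero, which is why the manual reset in processor.py is no longer needed.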

unit_tests/manifest_server/command_processor/test_processor.py

Lines changed: 0 additions & 54 deletions

@@ -308,57 +308,3 @@ def test_discover_with_trace_error(self, command_processor, sample_config):
         # Verify exception is raised
         with pytest.raises(HTTPException):
             command_processor.discover(sample_config)
-
-    def test_test_read_resets_global_record_counter(
-        self, command_processor, sample_config, sample_catalog
-    ):
-        """Test that test_read resets the global total_record_counter between calls."""
-        from airbyte_cdk.sources.declarative import stream_slicers
-
-        # Mock the TestReader
-        with patch(
-            "airbyte_cdk.manifest_server.command_processor.processor.TestReader"
-        ) as mock_test_reader_class:
-            mock_test_reader_instance = Mock()
-            mock_test_reader_class.return_value = mock_test_reader_instance
-            mock_stream_read = Mock()
-            mock_test_reader_instance.run_test_read.return_value = mock_stream_read
-
-            # Set initial counter value to simulate previous test_read execution
-            stream_slicers.declarative_partition_generator.total_record_counter = 100
-
-            # Execute test_read
-            result1 = command_processor.test_read(
-                config=sample_config,
-                catalog=sample_catalog,
-                state=[],
-                record_limit=50,
-                page_limit=3,
-                slice_limit=7,
-            )
-
-            # Verify counter was reset to 0 before the test_read
-            assert stream_slicers.declarative_partition_generator.total_record_counter == 0
-
-            # Set counter again to simulate state from first call
-            stream_slicers.declarative_partition_generator.total_record_counter = 200
-
-            # Execute another test_read
-            result2 = command_processor.test_read(
-                config=sample_config,
-                catalog=sample_catalog,
-                state=[],
-                record_limit=25,
-                page_limit=2,
-                slice_limit=5,
-            )
-
-            # Verify counter was reset again
-            assert stream_slicers.declarative_partition_generator.total_record_counter == 0
-
-            # Verify both calls returned the expected results
-            assert result1 == mock_stream_read
-            assert result2 == mock_stream_read
-
-            # Verify TestReader was called twice
-            assert mock_test_reader_class.call_count == 2
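
With the counter no longer global, this test has nothing to reset or assert against module state, so it is removed outright. If coverage for counter isolation is still wanted, a hypothetical test (not part of this commit) could assert on RecordCounter instances directly:

from airbyte_cdk.sources.declarative.stream_slicers.declarative_partition_generator import (
    RecordCounter,
)


def test_record_counters_are_independent():
    first, second = RecordCounter(), RecordCounter()
    first.increment()
    first.increment()
    assert first.get_total_records() == 2
    assert second.get_total_records() == 0  # no shared module-level state
    first.reset()
    assert first.get_total_records() == 0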

unit_tests/sources/declarative/stream_slicers/test_declarative_partition_generator.py

Lines changed: 5 additions & 11 deletions

@@ -4,8 +4,6 @@
 from unittest import TestCase
 from unittest.mock import Mock

-# This allows for the global total_record_counter to be reset between tests
-import airbyte_cdk.sources.declarative.stream_slicers.declarative_partition_generator as declarative_partition_generator
 from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level, Type
 from airbyte_cdk.sources.declarative.retrievers import Retriever
 from airbyte_cdk.sources.declarative.schema import InlineSchemaLoader
@@ -35,7 +33,7 @@ class StreamSlicerPartitionGeneratorTest(TestCase):
     def test_given_multiple_slices_partition_generator_uses_the_same_retriever(self) -> None:
         retriever = self._mock_retriever([])
         message_repository = Mock(spec=MessageRepository)
-        partition_factory = declarative_partition_generator.DeclarativePartitionFactory(
+        partition_factory = DeclarativePartitionFactory(
             _STREAM_NAME,
             _SCHEMA_LOADER,
             retriever,
@@ -50,7 +48,7 @@ def test_given_multiple_slices_partition_generator_uses_the_same_retriever(self)
     def test_given_a_mapping_when_read_then_yield_record(self) -> None:
         retriever = self._mock_retriever([_A_RECORD])
         message_repository = Mock(spec=MessageRepository)
-        partition_factory = declarative_partition_generator.DeclarativePartitionFactory(
+        partition_factory = DeclarativePartitionFactory(
             _STREAM_NAME,
             _SCHEMA_LOADER,
             retriever,
@@ -68,7 +66,7 @@ def test_given_a_mapping_when_read_then_yield_record(self) -> None:
     def test_given_not_a_record_when_read_then_send_to_message_repository(self) -> None:
         retriever = self._mock_retriever([_AIRBYTE_LOG_MESSAGE])
         message_repository = Mock(spec=MessageRepository)
-        partition_factory = declarative_partition_generator.DeclarativePartitionFactory(
+        partition_factory = DeclarativePartitionFactory(
             _STREAM_NAME,
             _SCHEMA_LOADER,
             retriever,
@@ -80,8 +78,6 @@ def test_given_not_a_record_when_read_then_send_to_message_repository(self) -> None:
         message_repository.emit_message.assert_called_once_with(_AIRBYTE_LOG_MESSAGE)

     def test_max_records_reached_stops_reading(self) -> None:
-        declarative_partition_generator.total_record_counter = 0
-
         expected_records = [
             Record(data={"id": 1, "name": "Max"}, stream_name="stream_name"),
             Record(data={"id": 1, "name": "Oscar"}, stream_name="stream_name"),
@@ -97,7 +93,7 @@ def test_max_records_reached_stops_reading(self) -> None:

         retriever = self._mock_retriever(mock_records)
         message_repository = Mock(spec=MessageRepository)
-        partition_factory = declarative_partition_generator.DeclarativePartitionFactory(
+        partition_factory = DeclarativePartitionFactory(
             _STREAM_NAME,
             _SCHEMA_LOADER,
             retriever,
@@ -113,8 +109,6 @@ def test_max_records_reached_stops_reading(self) -> None:
         assert actual_records == expected_records

     def test_max_records_reached_on_previous_partition(self) -> None:
-        declarative_partition_generator.total_record_counter = 0
-
         expected_records = [
             Record(data={"id": 1, "name": "Max"}, stream_name="stream_name"),
             Record(data={"id": 1, "name": "Oscar"}, stream_name="stream_name"),
@@ -128,7 +122,7 @@ def test_max_records_reached_on_previous_partition(self) -> None:

         retriever = self._mock_retriever(mock_records)
         message_repository = Mock(spec=MessageRepository)
-        partition_factory = declarative_partition_generator.DeclarativePartitionFactory(
+        partition_factory = DeclarativePartitionFactory(
             _STREAM_NAME,
             _SCHEMA_LOADER,
             retriever,
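
One loose end: the new RecordCounter defines reset(), but nothing in this commit calls it; presumably it exists for callers that want to reuse a factory's counter. Hypothetical usage:

counter = RecordCounter()  # from declarative_partition_generator
counter.increment()
assert counter.get_total_records() == 1
counter.reset()  # start a fresh count without constructing a new factory
assert counter.get_total_records() == 0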
