
Commit acc2e1f

move limiting to global variable on DeclarativePartition
1 parent d1548c0

6 files changed: +121 −166 lines

airbyte_cdk/sources/declarative/concurrent_declarative_source.py

Lines changed: 24 additions & 17 deletions
@@ -126,7 +126,6 @@ def __init__(
             max_concurrent_async_job_count=source_config.get("max_concurrent_async_job_count"),
             limit_pages_fetched_per_slice=limits.max_pages_per_slice if limits else None,
             limit_slices_fetched=limits.max_slices if limits else None,
-            limit_max_records=limits.max_records if limits else None,
             disable_retries=True if limits else False,
             disable_cache=True if limits else False,
         )
@@ -325,10 +324,13 @@ def _group_streams(

                 partition_generator = StreamSlicerPartitionGenerator(
                     partition_factory=DeclarativePartitionFactory(
-                        declarative_stream.name,
-                        declarative_stream.get_json_schema(),
-                        retriever,
-                        self.message_repository,
+                        stream_name=declarative_stream.name,
+                        json_schema=declarative_stream.get_json_schema(),
+                        retriever=retriever,
+                        message_repository=self.message_repository,
+                        max_records_limit=self._limits.max_records
+                        if self._limits
+                        else None,
                     ),
                     stream_slicer=declarative_stream.retriever.stream_slicer,
                     slice_limit=self._limits.max_slices
@@ -359,10 +361,13 @@ def _group_streams(
                 )
                 partition_generator = StreamSlicerPartitionGenerator(
                     partition_factory=DeclarativePartitionFactory(
-                        declarative_stream.name,
-                        declarative_stream.get_json_schema(),
-                        retriever,
-                        self.message_repository,
+                        stream_name=declarative_stream.name,
+                        json_schema=declarative_stream.get_json_schema(),
+                        retriever=retriever,
+                        message_repository=self.message_repository,
+                        max_records_limit=self._limits.max_records
+                        if self._limits
+                        else None,
                     ),
                     stream_slicer=cursor,
                     slice_limit=self._limits.max_slices if self._limits else None,
@@ -391,10 +396,11 @@ def _group_streams(
             ) and hasattr(declarative_stream.retriever, "stream_slicer"):
                 partition_generator = StreamSlicerPartitionGenerator(
                     DeclarativePartitionFactory(
-                        declarative_stream.name,
-                        declarative_stream.get_json_schema(),
-                        declarative_stream.retriever,
-                        self.message_repository,
+                        stream_name=declarative_stream.name,
+                        json_schema=declarative_stream.get_json_schema(),
+                        retriever=declarative_stream.retriever,
+                        message_repository=self.message_repository,
+                        max_records_limit=self._limits.max_records if self._limits else None,
                     ),
                     declarative_stream.retriever.stream_slicer,
                     slice_limit=self._limits.max_slices
@@ -455,10 +461,11 @@ def _group_streams(

                 partition_generator = StreamSlicerPartitionGenerator(
                     DeclarativePartitionFactory(
-                        declarative_stream.name,
-                        declarative_stream.get_json_schema(),
-                        retriever,
-                        self.message_repository,
+                        stream_name=declarative_stream.name,
+                        json_schema=declarative_stream.get_json_schema(),
+                        retriever=retriever,
+                        message_repository=self.message_repository,
+                        max_records_limit=self._limits.max_records if self._limits else None,
                     ),
                     perpartition_cursor,
                     slice_limit=self._limits.max_slices if self._limits else None,
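
Taken together, these hunks switch the factory calls to keyword arguments and route the Connector Builder record cap into DeclarativePartitionFactory instead of the retriever. A minimal sketch of the new wiring; `FakeLimits` is a stand-in for the CDK's test-limits object, whose exact shape is not shown in this diff:

```python
# Sketch only: `FakeLimits` approximates the limits object used above; the
# real one also carries max_pages_per_slice and other fields.
from typing import Optional


class FakeLimits:
    def __init__(self, max_records: Optional[int] = None, max_slices: Optional[int] = None):
        self.max_records = max_records
        self.max_slices = max_slices


def partition_factory_kwargs(limits: Optional[FakeLimits]) -> dict:
    # After this commit the record cap is handed to the partition factory once
    # per stream; retrievers no longer receive a `max_records` argument.
    return {"max_records_limit": limits.max_records if limits else None}


print(partition_factory_kwargs(FakeLimits(max_records=100)))  # {'max_records_limit': 100}
print(partition_factory_kwargs(None))  # {'max_records_limit': None}
```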

airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py

Lines changed: 0 additions & 3 deletions
@@ -634,7 +634,6 @@ def __init__(
         self,
         limit_pages_fetched_per_slice: Optional[int] = None,
         limit_slices_fetched: Optional[int] = None,
-        limit_max_records: Optional[int] = None,
         emit_connector_builder_messages: bool = False,
         disable_retries: bool = False,
         disable_cache: bool = False,
@@ -646,7 +645,6 @@ def __init__(
         self._init_mappings()
         self._limit_pages_fetched_per_slice = limit_pages_fetched_per_slice
         self._limit_slices_fetched = limit_slices_fetched
-        self._limit_max_records = limit_max_records
         self._emit_connector_builder_messages = emit_connector_builder_messages
         self._disable_retries = disable_retries
         self._disable_cache = disable_cache
@@ -3400,7 +3398,6 @@ def _get_url() -> str:
             ignore_stream_slicer_parameters_on_paginated_requests=ignore_stream_slicer_parameters_on_paginated_requests,
             additional_query_properties=query_properties,
             log_formatter=self._get_log_formatter(log_formatter, name),
-            max_records=self._limit_max_records,
             parameters=model.parameters or {},
         )

airbyte_cdk/sources/declarative/retrievers/simple_retriever.py

Lines changed: 0 additions & 15 deletions
@@ -92,7 +92,6 @@ class SimpleRetriever(Retriever):
     ignore_stream_slicer_parameters_on_paginated_requests: bool = False
     additional_query_properties: Optional[QueryProperties] = None
     log_formatter: Optional[Callable[[requests.Response], Any]] = None
-    max_records: Optional[int] = None

     def __post_init__(self, parameters: Mapping[str, Any]) -> None:
         self._paginator = self.paginator or NoPagination(parameters=parameters)
@@ -102,7 +101,6 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
             if isinstance(self._name, str)
             else self._name
         )
-        self._total_records_read = 0

     @property  # type: ignore
     def name(self) -> str:
@@ -503,12 +501,6 @@ def read_records(
         :param stream_slice: The stream slice to read data for
         :return: The records read from the API source
         """
-
-        # For Connector Builder test read operations, if the max number of records has already been
-        # reached, we just return without attempting to extract any more records
-        if self.max_records and self._total_records_read >= self.max_records:
-            return
-
         _slice = stream_slice or StreamSlice(partition={}, cursor_slice={})  # None-check

         most_recent_record_from_slice = None
@@ -537,13 +529,6 @@ def read_records(

             yield stream_data

-            # For Connector Builder test read operations, if the max number of records is reached, we
-            # exit the process early without emitting more records or attempting to extract more
-            if self.max_records:
-                self._total_records_read += 1
-                if self._total_records_read >= self.max_records:
-                    break
-
         if self.cursor:
             self.cursor.close_slice(_slice)
         return
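
With `max_records` and `_total_records_read` removed, SimpleRetriever no longer enforces any record cap itself. A plausible motivation, inferred from the factory docstring below rather than stated in the commit: when each partition works against its own copy of the retriever, per-instance counters never add up to a single shared cap. A toy illustration of that failure mode:

```python
# Hypothetical illustration (not CDK code): per-instance counting undercounts
# when each partition gets its own counter-carrying reader.
class PerInstanceCounter:
    def __init__(self, max_records: int):
        self.max_records = max_records
        self.total = 0

    def read(self, records):
        for r in records:
            if self.total >= self.max_records:
                return
            self.total += 1
            yield r


# Two partitions, two independent copies: each enforces the cap on its own,
# so 8 records come out even though the intended cap is 5.
cap = 5
emitted = []
for partition_records in ([1] * 4, [2] * 4):
    reader = PerInstanceCounter(cap)
    emitted.extend(reader.read(partition_records))
print(len(emitted))  # 8, not 5
```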

airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py

Lines changed: 24 additions & 5 deletions
@@ -13,6 +13,11 @@
 from airbyte_cdk.sources.types import Record, StreamSlice
 from airbyte_cdk.utils.slice_hasher import SliceHasher

+# For Connector Builder test read operations, we track the total number of records
+# read for the stream at the global level so that we can stop reading early if we
+# exceed the record limit
+total_record_counter = 0
+

 class DeclarativePartitionFactory:
     def __init__(
@@ -21,6 +26,7 @@ def __init__(
         json_schema: Mapping[str, Any],
         retriever: Retriever,
         message_repository: MessageRepository,
+        max_records_limit: Optional[int] = None,
     ) -> None:
         """
         The DeclarativePartitionFactory takes a retriever_factory and not a retriever directly. The reason is that our components are not
@@ -31,14 +37,16 @@ def __init__(
         self._json_schema = json_schema
         self._retriever = retriever
         self._message_repository = message_repository
+        self._max_records_limit = max_records_limit

     def create(self, stream_slice: StreamSlice) -> Partition:
         return DeclarativePartition(
-            self._stream_name,
-            self._json_schema,
-            self._retriever,
-            self._message_repository,
-            stream_slice,
+            stream_name=self._stream_name,
+            json_schema=self._json_schema,
+            retriever=self._retriever,
+            message_repository=self._message_repository,
+            max_records_limit=self._max_records_limit,
+            stream_slice=stream_slice,
         )


@@ -49,17 +57,24 @@ def __init__(
         json_schema: Mapping[str, Any],
         retriever: Retriever,
         message_repository: MessageRepository,
+        max_records_limit: Optional[int],
         stream_slice: StreamSlice,
     ):
         self._stream_name = stream_name
         self._json_schema = json_schema
         self._retriever = retriever
         self._message_repository = message_repository
+        self._max_records_limit = max_records_limit
         self._stream_slice = stream_slice
         self._hash = SliceHasher.hash(self._stream_name, self._stream_slice)

     def read(self) -> Iterable[Record]:
         for stream_data in self._retriever.read_records(self._json_schema, self._stream_slice):
+            if self._max_records_limit:
+                global total_record_counter
+                if total_record_counter >= self._max_records_limit:
+                    break
+
             if isinstance(stream_data, Mapping):
                 record = (
                     stream_data
@@ -74,6 +89,9 @@ def read(self) -> Iterable[Record]:
             else:
                 self._message_repository.emit_message(stream_data)

+            if self._max_records_limit:
+                total_record_counter += 1
+
     def to_slice(self) -> Optional[Mapping[str, Any]]:
         return self._stream_slice

@@ -90,6 +108,7 @@ def __init__(
         partition_factory: DeclarativePartitionFactory,
         stream_slicer: StreamSlicer,
         slice_limit: Optional[int] = None,
+        max_records_limit: Optional[int] = None,
     ) -> None:
         self._partition_factory = partition_factory
unit_tests/sources/declarative/retrievers/test_simple_retriever.py

Lines changed: 0 additions & 119 deletions
@@ -1566,122 +1566,3 @@ def test_simple_retriever_still_emit_records_if_no_merge_key():

     assert len(actual_records) == 10
     assert actual_records == expected_records
-
-
-def test_simple_retriever_max_records_reached():
-    expected_records = [
-        Record(data={"id": 1, "name": "Max"}, stream_name="stream_name"),
-        Record(data={"id": 1, "name": "Oscar"}, stream_name="stream_name"),
-        Record(data={"id": 1, "name": "Charles"}, stream_name="stream_name"),
-        Record(data={"id": 1, "name": "Alex"}, stream_name="stream_name"),
-        Record(data={"id": 1, "name": "Yuki"}, stream_name="stream_name"),
-    ]
-
-    mock_records = expected_records + [
-        Record(data={"id": 1, "name": "Lewis"}, stream_name="stream_name"),
-        Record(data={"id": 1, "name": "Lando"}, stream_name="stream_name"),
-    ]
-
-    record_selector = MagicMock()
-    record_selector.select_records.return_value = []
-
-    retriever = SimpleRetriever(
-        name="stream_name",
-        primary_key=primary_key,
-        requester=MagicMock(),
-        paginator=Mock(),
-        record_selector=record_selector,
-        max_records=5,
-        parameters={},
-        config={},
-    )
-
-    stream_slice = StreamSlice(cursor_slice={}, partition={"repository": "airbyte"})
-
-    with patch.object(
-        SimpleRetriever,
-        "_read_pages",
-        return_value=iter(mock_records),
-    ):
-        actual_records = list(retriever.read_records(stream_slice=stream_slice, records_schema={}))
-
-    assert len(actual_records) == 5
-    assert actual_records == expected_records
-
-
-def test_simple_retriever_max_records_already_reached_on_previous_read():
-    mock_records = [
-        Record(data={"id": 1, "name": "Max"}, stream_name="stream_name"),
-        Record(data={"id": 1, "name": "Oscar"}, stream_name="stream_name"),
-        Record(data={"id": 1, "name": "Charles"}, stream_name="stream_name"),
-        Record(data={"id": 1, "name": "Alex"}, stream_name="stream_name"),
-        Record(data={"id": 1, "name": "Yuki"}, stream_name="stream_name"),
-    ]
-
-    record_selector = MagicMock()
-    record_selector.select_records.return_value = []
-
-    retriever = SimpleRetriever(
-        name="stream_name",
-        primary_key=primary_key,
-        requester=MagicMock(),
-        paginator=Mock(),
-        record_selector=record_selector,
-        max_records=5,
-        parameters={},
-        config={},
-    )
-    retriever._total_records_read = 5
-
-    stream_slice = StreamSlice(cursor_slice={}, partition={"repository": "airbyte"})
-
-    with patch.object(
-        SimpleRetriever,
-        "_read_pages",
-        return_value=iter(mock_records),
-    ):
-        actual_records = list(retriever.read_records(stream_slice=stream_slice, records_schema={}))
-
-    assert len(actual_records) == 0
-
-
-def test_simple_retriever_read_some_records():
-    expected_records = [
-        Record(data={"id": 1, "name": "Max"}, stream_name="stream_name"),
-        Record(data={"id": 1, "name": "Oscar"}, stream_name="stream_name"),
-        Record(data={"id": 1, "name": "Charles"}, stream_name="stream_name"),
-    ]
-
-    mock_records = expected_records + [
-        Record(data={"id": 1, "name": "Alex"}, stream_name="stream_name"),
-        Record(data={"id": 1, "name": "Yuki"}, stream_name="stream_name"),
-    ]
-
-    record_selector = MagicMock()
-    record_selector.select_records.return_value = []
-
-    retriever = SimpleRetriever(
-        name="stream_name",
-        primary_key=primary_key,
-        requester=MagicMock(),
-        paginator=Mock(),
-        record_selector=record_selector,
-        max_records=5,
-        parameters={},
-        config={},
-    )
-    retriever._total_records_read = 2
-
-    stream_slice = StreamSlice(cursor_slice={}, partition={"repository": "airbyte"})
-
-    with patch.object(
-        SimpleRetriever,
-        "_read_pages",
-        return_value=iter(mock_records),
-    ):
-        actual_records = list(retriever.read_records(stream_slice=stream_slice, records_schema={}))
-
-    assert len(actual_records) == 3
-    assert actual_records == expected_records
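
The three retriever-level tests above lose their subject with this commit. A hedged sketch of how the same behavior might now be exercised at the partition level; names follow the diff, but the exact Record/emission semantics of DeclarativePartition.read() are only partly visible here, so treat this as illustrative rather than the repository's actual replacement test:

```python
# Illustrative only: exercises the global-counter limiting shown in
# declarative_partition_generator.py. Assumes dict records flow through
# read(); the real implementation wraps mappings in Record objects.
from unittest.mock import MagicMock

from airbyte_cdk.sources.declarative.stream_slicers import (
    declarative_partition_generator as dpg,
)
from airbyte_cdk.sources.types import StreamSlice


def test_partition_read_stops_at_global_record_limit():
    dpg.total_record_counter = 0  # the module-level counter must be reset between tests

    retriever = MagicMock()
    retriever.read_records.return_value = iter([{"id": i} for i in range(10)])

    partition = dpg.DeclarativePartition(
        stream_name="stream_name",
        json_schema={},
        retriever=retriever,
        message_repository=MagicMock(),
        max_records_limit=5,
        stream_slice=StreamSlice(cursor_slice={}, partition={"repository": "airbyte"}),
    )

    # Only the first 5 of 10 available records should be emitted.
    assert len(list(partition.read())) == 5
```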
