Commit d1548c0

remove per-partition limits in favor of global limit on SimpleRetriever
1 parent 2580057 commit d1548c0

8 files changed (+146, −148 lines)
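In short, the Connector Builder's record cap moves out of the concurrent framework (where PartitionReader counted records per partition) and onto SimpleRetriever, where a single counter spans all slices. A minimal sketch of the new wiring, assuming a TestLimits shape inferred from the `limits.max_records` accesses in the hunks below:

    # Hedged sketch, not CDK code: `TestLimits` and `build` are illustrative.
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class TestLimits:
        max_records: Optional[int] = None  # assumed field, per `limits.max_records`

    @dataclass
    class Retriever:  # stand-in for SimpleRetriever's new `max_records` field
        max_records: Optional[int] = None

    def build(limits: Optional[TestLimits]) -> Retriever:
        # The cap is forwarded through ModelToComponentFactory(limit_max_records)
        # to the retriever, instead of max_records_per_partition on ConcurrentSource.
        return Retriever(max_records=limits.max_records if limits else None)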

airbyte_cdk/connector_builder/test_reader/message_grouper.py

Lines changed: 0 additions & 10 deletions

@@ -96,16 +96,6 @@ def get_message_groups(
     slice_auxiliary_requests: List[AuxiliaryRequest] = []

     while message := next(messages, None):
-        # Even though we do not emit records beyond the limit in the message group response, we still
-        # need to process messages off the queue in order to avoid a deadlock that occurs if the amount
-        # of extracted records exceeds the size of the queue (which has a default of 10,000)
-        #
-        # A few other options considered were killing the thread pool, but that doesn't kill in-progress
-        # threads. We also considered adding another event to the main queue, but this is
-        # the simplest solution for the time being.
-        if records_count >= limit:
-            continue
         json_message = airbyte_message_to_json(message)

         if is_page_http_request_for_different_stream(json_message, stream_name):
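The deleted guard drained messages past the limit because a bounded queue deadlocks if producers keep writing while the consumer stops reading; with the cap now enforced inside SimpleRetriever, nothing is produced beyond the limit, so the drain-and-discard loop is no longer needed. A generic illustration of that hazard (plain Python, not CDK code):

    from queue import Full, Queue
    from threading import Thread

    q: Queue = Queue(maxsize=2)  # the CDK queue's default maxsize is 10,000

    def producer() -> None:
        for i in range(5):
            try:
                q.put(i, timeout=1)  # without a timeout, blocks forever once full
            except Full:
                print(f"producer stuck at item {i}: consumer stopped draining")
                return

    t = Thread(target=producer)
    t.start()
    print("consumed:", q.get())  # drain a single item, then stop consuming
    t.join()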

airbyte_cdk/sources/concurrent_source/concurrent_source.py

Lines changed: 0 additions & 5 deletions

@@ -45,7 +45,6 @@ def create(
         message_repository: MessageRepository,
         queue: Optional[Queue[QueueItem]] = None,
         timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
-        max_records_per_partition: Optional[int] = None,
     ) -> "ConcurrentSource":
         is_single_threaded = initial_number_of_partitions_to_generate == 1 and num_workers == 1
         too_many_generator = (
@@ -68,7 +67,6 @@ def create(
             message_repository=message_repository,
             initial_number_partitions_to_generate=initial_number_of_partitions_to_generate,
             timeout_seconds=timeout_seconds,
-            max_records_per_partition=max_records_per_partition,
         )

     def __init__(
@@ -80,7 +78,6 @@ def __init__(
         message_repository: MessageRepository = InMemoryMessageRepository(),
         initial_number_partitions_to_generate: int = 1,
         timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
-        max_records_per_partition: Optional[int] = None,
     ) -> None:
         """
         :param threadpool: The threadpool to submit tasks to
@@ -96,7 +93,6 @@ def __init__(
         self._message_repository = message_repository
         self._initial_number_partitions_to_generate = initial_number_partitions_to_generate
         self._timeout_seconds = timeout_seconds
-        self._max_records_per_partition = max_records_per_partition

         # We set a maxsize so the main thread can process record items when the queue size grows. This assumes that there are
         # fewer threads generating partitions than the max number of workers. If it weren't the case, we could have threads only generating
@@ -119,7 +115,6 @@ def read(
             PartitionReader(
                 self._queue,
                 PartitionLogger(self._slice_logger, self._logger, self._message_repository),
-                self._max_records_per_partition,
             ),
         )

airbyte_cdk/sources/declarative/concurrent_declarative_source.py

Lines changed: 1 addition & 1 deletion

@@ -126,6 +126,7 @@ def __init__(
             max_concurrent_async_job_count=source_config.get("max_concurrent_async_job_count"),
             limit_pages_fetched_per_slice=limits.max_pages_per_slice if limits else None,
             limit_slices_fetched=limits.max_slices if limits else None,
+            limit_max_records=limits.max_records if limits else None,
             disable_retries=True if limits else False,
             disable_cache=True if limits else False,
         )
@@ -170,7 +171,6 @@ def __init__(
             slice_logger=self._slice_logger,
             queue=queue,
             message_repository=self.message_repository,
-            max_records_per_partition=limits.max_records if limits else None,
         )

         # TODO: Remove this. This property is necessary to safely migrate Stripe during the transition state.

airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py

Lines changed: 3 additions & 0 deletions

@@ -634,6 +634,7 @@ def __init__(
         self,
         limit_pages_fetched_per_slice: Optional[int] = None,
         limit_slices_fetched: Optional[int] = None,
+        limit_max_records: Optional[int] = None,
         emit_connector_builder_messages: bool = False,
         disable_retries: bool = False,
         disable_cache: bool = False,
@@ -645,6 +646,7 @@ def __init__(
         self._init_mappings()
         self._limit_pages_fetched_per_slice = limit_pages_fetched_per_slice
         self._limit_slices_fetched = limit_slices_fetched
+        self._limit_max_records = limit_max_records
         self._emit_connector_builder_messages = emit_connector_builder_messages
         self._disable_retries = disable_retries
         self._disable_cache = disable_cache
@@ -3398,6 +3400,7 @@ def _get_url() -> str:
             ignore_stream_slicer_parameters_on_paginated_requests=ignore_stream_slicer_parameters_on_paginated_requests,
             additional_query_properties=query_properties,
             log_formatter=self._get_log_formatter(log_formatter, name),
+            max_records=self._limit_max_records,
             parameters=model.parameters or {},
         )

airbyte_cdk/sources/declarative/retrievers/simple_retriever.py

Lines changed: 15 additions & 0 deletions

@@ -92,6 +92,7 @@ class SimpleRetriever(Retriever):
     ignore_stream_slicer_parameters_on_paginated_requests: bool = False
     additional_query_properties: Optional[QueryProperties] = None
     log_formatter: Optional[Callable[[requests.Response], Any]] = None
+    max_records: Optional[int] = None

     def __post_init__(self, parameters: Mapping[str, Any]) -> None:
         self._paginator = self.paginator or NoPagination(parameters=parameters)
@@ -101,6 +102,7 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
             if isinstance(self._name, str)
             else self._name
         )
+        self._total_records_read = 0

     @property  # type: ignore
     def name(self) -> str:
@@ -501,6 +503,12 @@ def read_records(
         :param stream_slice: The stream slice to read data for
         :return: The records read from the API source
         """
+
+        # For Connector Builder test read operations, if the max number of records has already been
+        # reached, we just return without attempting to extract any more records
+        if self.max_records and self._total_records_read >= self.max_records:
+            return
+
         _slice = stream_slice or StreamSlice(partition={}, cursor_slice={})  # None-check

         most_recent_record_from_slice = None
@@ -529,6 +537,13 @@ def read_records(

                 yield stream_data

+                # For Connector Builder test read operations, if the max number of records is reached, we
+                # exit the process early without emitting more records or attempting to extract more
+                if self.max_records:
+                    self._total_records_read += 1
+                    if self._total_records_read >= self.max_records:
+                        break
+
         if self.cursor:
             self.cursor.close_slice(_slice)
         return
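The two added checks share one counter that persists across calls to read_records, so the cap applies across slices rather than per slice. A toy model of that behavior (illustrative, not CDK code):

    from typing import Iterable, Iterator, Optional

    class ToyRetriever:
        def __init__(self, max_records: Optional[int] = None) -> None:
            self.max_records = max_records
            self._total_records_read = 0  # persists across read_records calls

        def read_records(self, slice_data: Iterable[int]) -> Iterator[int]:
            if self.max_records and self._total_records_read >= self.max_records:
                return  # a later slice: limit already hit, emit nothing
            for record in slice_data:
                yield record
                if self.max_records:
                    self._total_records_read += 1
                    if self._total_records_read >= self.max_records:
                        break

    r = ToyRetriever(max_records=3)
    out = [rec for s in ([1, 2], [3, 4], [5]) for rec in r.read_records(s)]
    print(out)  # [1, 2, 3] — the limit spans slices, not each slice separately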

airbyte_cdk/sources/streams/concurrent/partition_reader.py

Lines changed: 0 additions & 11 deletions

@@ -53,14 +53,12 @@ def __init__(
         self,
         queue: Queue[QueueItem],
         partition_logger: Optional[PartitionLogger] = None,
-        max_records_per_partition: Optional[int] = None,
     ) -> None:
         """
         :param queue: The queue to put the records in.
         """
         self._queue = queue
         self._partition_logger = partition_logger
-        self._max_records_per_partition = max_records_per_partition

     def process_partition(self, partition: Partition, cursor: Cursor) -> None:
         """
@@ -78,18 +76,9 @@ def process_partition(self, partition: Partition, cursor: Cursor) -> None:
             if self._partition_logger:
                 self._partition_logger.log(partition)

-            record_count = 0
             for record in partition.read():
                 self._queue.put(record)
                 cursor.observe(record)
-                record_count += 1
-                if (
-                    self._max_records_per_partition
-                    and record_count >= self._max_records_per_partition
-                ):
-                    # We stop processing a partition after exceeding the max_records for Connector
-                    # Builder test reads. The record limit only applies to an individual partition
-                    break
             cursor.close_partition(partition)
             self._queue.put(PartitionCompleteSentinel(partition, self._IS_SUCCESSFUL))
         except Exception as e:
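For contrast with the removed logic: because record_count reset for every partition, a stream with N partitions could emit up to N × max_records records in a test read, which the global counter in SimpleRetriever fixes. A toy comparison (illustrative numbers, not CDK code):

    partitions = [[1, 2, 3], [4, 5, 6]]
    max_records = 2

    # Per-partition counting: the cap applies to each partition independently.
    per_partition = [rec for part in partitions for rec in part[:max_records]]
    print(per_partition)  # [1, 2, 4, 5] — 4 records despite a limit of 2

    # Global counting: one counter shared across partitions.
    total, global_limited = 0, []
    for part in partitions:
        for rec in part:
            if total >= max_records:
                break
            global_limited.append(rec)
            total += 1
    print(global_limited)  # [1, 2] — the SimpleRetriever-style global cap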

unit_tests/connector_builder/test_message_grouper.py

Lines changed: 0 additions & 120 deletions

@@ -307,126 +307,6 @@ def test_get_grouped_messages_with_logs(mock_entrypoint_read: Mock) -> None:
         assert actual_log == expected_logs[i]


-@pytest.mark.parametrize(
-    "request_record_limit, max_record_limit, should_fail",
-    [
-        pytest.param(1, 3, False, id="test_create_request_with_record_limit"),
-        pytest.param(3, 1, True, id="test_create_request_record_limit_exceeds_max"),
-    ],
-)
-@patch("airbyte_cdk.connector_builder.test_reader.reader.AirbyteEntrypoint.read")
-def test_get_grouped_messages_record_limit(
-    mock_entrypoint_read: Mock, request_record_limit: int, max_record_limit: int, should_fail: bool
-) -> None:
-    stream_name = "hashiras"
-    url = "https://demonslayers.com/api/v1/hashiras?era=taisho"
-    request = {
-        "headers": {"Content-Type": "application/json"},
-        "method": "GET",
-        "body": {"content": '{"custom": "field"}'},
-    }
-    response = {
-        "status_code": 200,
-        "headers": {"field": "value"},
-        "body": {"content": '{"name": "field"}'},
-    }
-    mock_source = make_mock_source(
-        mock_entrypoint_read,
-        iter(
-            [
-                request_response_log_message(request, response, url, stream_name),
-                record_message(stream_name, {"name": "Shinobu Kocho"}),
-                record_message(stream_name, {"name": "Muichiro Tokito"}),
-                request_response_log_message(request, response, url, stream_name),
-                record_message(stream_name, {"name": "Mitsuri Kanroji"}),
-            ]
-        ),
-    )
-    n_records = 2
-    record_limit = min(request_record_limit, max_record_limit)

-    api = TestReader(MAX_PAGES_PER_SLICE, MAX_SLICES, max_record_limit=max_record_limit)
-    # this is the call we expect to raise an exception
-    if should_fail:
-        with pytest.raises(ValueError):
-            api.run_test_read(
-                mock_source,
-                config=CONFIG,
-                configured_catalog=create_configured_catalog(stream_name),
-                stream_name=stream_name,
-                state=_NO_STATE,
-                record_limit=request_record_limit,
-            )
-    else:
-        actual_response: StreamRead = api.run_test_read(
-            mock_source,
-            config=CONFIG,
-            configured_catalog=create_configured_catalog(stream_name),
-            stream_name=stream_name,
-            state=_NO_STATE,
-            record_limit=request_record_limit,
-        )
-        single_slice = actual_response.slices[0]
-        total_records = 0
-        for i, actual_page in enumerate(single_slice.pages):
-            total_records += len(actual_page.records)
-        assert total_records == min([record_limit, n_records])

-        assert (total_records >= max_record_limit) == actual_response.test_read_limit_reached
-
-
-@pytest.mark.parametrize(
-    "max_record_limit",
-    [
-        pytest.param(2, id="test_create_request_no_record_limit"),
-        pytest.param(1, id="test_create_request_no_record_limit_n_records_exceed_max"),
-    ],
-)
-@patch("airbyte_cdk.connector_builder.test_reader.reader.AirbyteEntrypoint.read")
-def test_get_grouped_messages_default_record_limit(
-    mock_entrypoint_read: Mock, max_record_limit: int
-) -> None:
-    stream_name = "hashiras"
-    url = "https://demonslayers.com/api/v1/hashiras?era=taisho"
-    request = {
-        "headers": {"Content-Type": "application/json"},
-        "method": "GET",
-        "body": {"content": '{"custom": "field"}'},
-    }
-    response = {
-        "status_code": 200,
-        "headers": {"field": "value"},
-        "body": {"content": '{"name": "field"}'},
-    }
-    mock_source = make_mock_source(
-        mock_entrypoint_read,
-        iter(
-            [
-                request_response_log_message(request, response, url, stream_name),
-                record_message(stream_name, {"name": "Shinobu Kocho"}),
-                record_message(stream_name, {"name": "Muichiro Tokito"}),
-                request_response_log_message(request, response, url, stream_name),
-                record_message(stream_name, {"name": "Mitsuri Kanroji"}),
-            ]
-        ),
-    )
-    n_records = 2

-    api = TestReader(MAX_PAGES_PER_SLICE, MAX_SLICES, max_record_limit=max_record_limit)
-    actual_response: StreamRead = api.run_test_read(
-        source=mock_source,
-        config=CONFIG,
-        configured_catalog=create_configured_catalog(stream_name),
-        stream_name=stream_name,
-        state=_NO_STATE,
-    )
-    single_slice = actual_response.slices[0]
-    total_records = 0
-    for i, actual_page in enumerate(single_slice.pages):
-        total_records += len(actual_page.records)
-    assert total_records == min([max_record_limit, n_records])
-
-
 @patch("airbyte_cdk.connector_builder.test_reader.reader.AirbyteEntrypoint.read")
 def test_get_grouped_messages_limit_0(mock_entrypoint_read: Mock) -> None:
     stream_name = "hashiras"
