Skip to content

Commit 23c9712

Browse files
author
maxime.c
committed
fix connector builder tests and format
1 parent c68ae59 commit 23c9712

File tree

9 files changed

+379
-169
lines changed

9 files changed

+379
-169
lines changed

airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,13 @@
1111
from datetime import timedelta
1212
from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional
1313

14-
from airbyte_cdk.models import AirbyteStateMessage, AirbyteStateBlob, AirbyteStreamState, AirbyteStateType, StreamDescriptor
14+
from airbyte_cdk.models import (
15+
AirbyteStateMessage,
16+
AirbyteStateBlob,
17+
AirbyteStreamState,
18+
AirbyteStateType,
19+
StreamDescriptor,
20+
)
1521
from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
1622
from airbyte_cdk.sources.declarative.incremental.global_substream_cursor import (
1723
Timer,
@@ -548,21 +554,33 @@ def limit_reached(self) -> bool:
548554
return self._number_of_partitions > self.SWITCH_TO_GLOBAL_LIMIT
549555

550556
@staticmethod
551-
def get_parent_state(stream_state: Optional[StreamState], parent_stream_name: str) -> Optional[AirbyteStateMessage]:
552-
return AirbyteStateMessage(
553-
type=AirbyteStateType.STREAM,
554-
stream=AirbyteStreamState(
555-
stream_descriptor=StreamDescriptor(parent_stream_name, None),
556-
stream_state=AirbyteStateBlob(stream_state["parent_state"][parent_stream_name])
557+
def get_parent_state(
558+
stream_state: Optional[StreamState], parent_stream_name: str
559+
) -> Optional[AirbyteStateMessage]:
560+
return (
561+
AirbyteStateMessage(
562+
type=AirbyteStateType.STREAM,
563+
stream=AirbyteStreamState(
564+
stream_descriptor=StreamDescriptor(parent_stream_name, None),
565+
stream_state=AirbyteStateBlob(stream_state["parent_state"][parent_stream_name]),
566+
),
557567
)
558-
) if stream_state and "parent_state" in stream_state else None
568+
if stream_state and "parent_state" in stream_state
569+
else None
570+
)
559571

560572
@staticmethod
561-
def get_global_state(stream_state: Optional[StreamState], parent_stream_name: str) -> Optional[AirbyteStateMessage]:
562-
return AirbyteStateMessage(
563-
type=AirbyteStateType.STREAM,
564-
stream=AirbyteStreamState(
565-
stream_descriptor=StreamDescriptor(parent_stream_name, None),
566-
stream_state=AirbyteStateBlob(stream_state["state"])
573+
def get_global_state(
574+
stream_state: Optional[StreamState], parent_stream_name: str
575+
) -> Optional[AirbyteStateMessage]:
576+
return (
577+
AirbyteStateMessage(
578+
type=AirbyteStateType.STREAM,
579+
stream=AirbyteStreamState(
580+
stream_descriptor=StreamDescriptor(parent_stream_name, None),
581+
stream_state=AirbyteStateBlob(stream_state["state"]),
582+
),
567583
)
568-
) if stream_state and "state" in stream_state else None
584+
if stream_state and "state" in stream_state
585+
else None
586+
)

airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py

Lines changed: 50 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,15 @@
3333
from airbyte_cdk.connector_builder.models import (
3434
LogMessage as ConnectorBuilderLogMessage,
3535
)
36-
from airbyte_cdk.models import FailureType, Level, AirbyteStateMessage, AirbyteStreamState, AirbyteStateBlob, AirbyteStateType, StreamDescriptor
36+
from airbyte_cdk.models import (
37+
FailureType,
38+
Level,
39+
AirbyteStateMessage,
40+
AirbyteStreamState,
41+
AirbyteStateBlob,
42+
AirbyteStateType,
43+
StreamDescriptor,
44+
)
3745
from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
3846
from airbyte_cdk.sources.declarative.async_job.job_orchestrator import AsyncJobOrchestrator
3947
from airbyte_cdk.sources.declarative.async_job.job_tracker import JobTracker
@@ -500,8 +508,9 @@
500508
InterpolatedRequestOptionsProvider,
501509
RequestOptionsProvider,
502510
)
503-
from airbyte_cdk.sources.declarative.requesters.request_options.per_partition_request_option_provider import \
504-
PerPartitionRequestOptionsProvider
511+
from airbyte_cdk.sources.declarative.requesters.request_options.per_partition_request_option_provider import (
512+
PerPartitionRequestOptionsProvider,
513+
)
505514
from airbyte_cdk.sources.declarative.requesters.request_path import RequestPath
506515
from airbyte_cdk.sources.declarative.requesters.requester import HttpMethod
507516
from airbyte_cdk.sources.declarative.resolvers import (
@@ -1282,7 +1291,9 @@ def create_concurrent_cursor_from_datetime_based_cursor(
12821291

12831292
# TODO validate and explain why we need to do this...
12841293
component_definition["$parameters"] = component_definition.get("parameters", {})
1285-
parameters = component_definition.get("parameters", component_definition.get("$parameters", {}))
1294+
parameters = component_definition.get(
1295+
"parameters", component_definition.get("$parameters", {})
1296+
)
12861297
datetime_based_cursor_model = model_type.parse_obj(component_definition)
12871298

12881299
if not isinstance(datetime_based_cursor_model, DatetimeBasedCursorModel):
@@ -1596,7 +1607,9 @@ def create_concurrent_cursor_from_perpartition_cursor(
15961607

15971608
interpolated_cursor_field = InterpolatedString.create(
15981609
datetime_based_cursor_model.cursor_field,
1599-
parameters=component_definition.get("parameters", component_definition.get("$parameters", {})), # FIXME validate and explain why we need to do this
1610+
parameters=component_definition.get(
1611+
"parameters", component_definition.get("$parameters", {})
1612+
), # FIXME validate and explain why we need to do this
16001613
)
16011614
cursor_field = CursorField(interpolated_cursor_field.eval(config=config))
16021615

@@ -1973,13 +1986,17 @@ def create_declarative_stream(
19731986
request_options_provider = (
19741987
datetime_request_options_provider
19751988
if not isinstance(concurrent_cursor, ConcurrentPerPartitionCursor)
1976-
else PerPartitionRequestOptionsProvider(partition_router, datetime_request_options_provider)
1989+
else PerPartitionRequestOptionsProvider(
1990+
partition_router, datetime_request_options_provider
1991+
)
19771992
)
19781993
elif model.incremental_sync and isinstance(
19791994
model.incremental_sync, IncrementingCountCursorModel
19801995
):
19811996
if isinstance(concurrent_cursor, ConcurrentPerPartitionCursor):
1982-
raise ValueError("PerPartition does not support per partition states because switching to global state is time based")
1997+
raise ValueError(
1998+
"PerPartition does not support per partition states because switching to global state is time based"
1999+
)
19832000

19842001
cursor_model: IncrementingCountCursorModel = model.incremental_sync # type: ignore
19852002

@@ -2019,7 +2036,9 @@ def create_declarative_stream(
20192036
)
20202037

20212038
stream_slicer: ConcurrentStreamSlicer = (
2022-
partition_router if isinstance(concurrent_cursor, FinalStateCursor) else concurrent_cursor
2039+
partition_router
2040+
if isinstance(concurrent_cursor, FinalStateCursor)
2041+
else concurrent_cursor
20232042
)
20242043
retriever = self._create_component_from_model(
20252044
model=model.retriever,
@@ -2088,8 +2107,7 @@ def create_declarative_stream(
20882107
logger=logging.getLogger(f"airbyte.{stream_name}"),
20892108
# FIXME this is a breaking change compared to the old implementation which used the source name instead
20902109
cursor=concurrent_cursor,
2091-
supports_file_transfer=hasattr(model, "file_uploader")
2092-
and bool(model.file_uploader),
2110+
supports_file_transfer=hasattr(model, "file_uploader") and bool(model.file_uploader),
20932111
)
20942112

20952113
def _is_stop_condition_on_cursor(self, model: DeclarativeStreamModel) -> bool:
@@ -3768,14 +3786,20 @@ def _create_message_repository_substream_wrapper(
37683786
self, model: ParentStreamConfigModel, config: Config, **kwargs: Any
37693787
) -> Any:
37703788
# getting the parent state
3771-
child_state = self._connector_state_manager.get_stream_state(kwargs["stream_name"], None) # FIXME adding `stream_name` as a parameter means it will be a breaking change. I assume this is mostly called internally so I don't think we need to bother that much about this but still raising the flag
3789+
child_state = self._connector_state_manager.get_stream_state(
3790+
kwargs["stream_name"], None
3791+
) # FIXME adding `stream_name` as a parameter means it will be a breaking change. I assume this is mostly called internally so I don't think we need to bother that much about this but still raising the flag
37723792
if model.incremental_dependency and child_state:
37733793
parent_stream_name = model.stream.name or ""
3774-
parent_state = ConcurrentPerPartitionCursor.get_parent_state(child_state, parent_stream_name)
3794+
parent_state = ConcurrentPerPartitionCursor.get_parent_state(
3795+
child_state, parent_stream_name
3796+
)
37753797

37763798
if not parent_state:
37773799
# there are two migration cases: state value from child stream or from global state
3778-
parent_state = ConcurrentPerPartitionCursor.get_global_state(child_state, parent_stream_name)
3800+
parent_state = ConcurrentPerPartitionCursor.get_global_state(
3801+
child_state, parent_stream_name
3802+
)
37793803

37803804
if not parent_state and not isinstance(parent_state, dict):
37813805
cursor_field = InterpolatedString.create(
@@ -3787,8 +3811,12 @@ def _create_message_repository_substream_wrapper(
37873811
parent_state = AirbyteStateMessage(
37883812
type=AirbyteStateType.STREAM,
37893813
stream=AirbyteStreamState(
3790-
stream_descriptor=StreamDescriptor(name=parent_stream_name, namespace=None),
3791-
stream_state=AirbyteStateBlob({cursor_field: list(cursor_values)[0]}),
3814+
stream_descriptor=StreamDescriptor(
3815+
name=parent_stream_name, namespace=None
3816+
),
3817+
stream_state=AirbyteStateBlob(
3818+
{cursor_field: list(cursor_values)[0]}
3819+
),
37923820
),
37933821
)
37943822
connector_state_manager = ConnectorStateManager([parent_state] if parent_state else [])
@@ -3804,7 +3832,10 @@ def _create_message_repository_substream_wrapper(
38043832
disable_cache=self._disable_cache,
38053833
message_repository=StateFilteringMessageRepository(
38063834
LogAppenderMessageRepositoryDecorator(
3807-
{"airbyte_cdk": {"stream": {"is_substream": True}}, "http": {"is_auxiliary": True}},
3835+
{
3836+
"airbyte_cdk": {"stream": {"is_substream": True}},
3837+
"http": {"is_auxiliary": True},
3838+
},
38083839
self._message_repository,
38093840
self._evaluate_log_level(self._emit_connector_builder_messages),
38103841
),
@@ -4127,7 +4158,9 @@ def create_grouping_partition_router(
41274158
self, model: GroupingPartitionRouterModel, config: Config, **kwargs: Any
41284159
) -> GroupingPartitionRouter:
41294160
underlying_router = self._create_component_from_model(
4130-
model=model.underlying_partition_router, config=config, **kwargs,
4161+
model=model.underlying_partition_router,
4162+
config=config,
4163+
**kwargs,
41314164
)
41324165
if model.group_size < 1:
41334166
raise ValueError(f"Group size must be greater than 0, got {model.group_size}")

airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,17 @@
77
import json
88
import logging
99
from dataclasses import InitVar, dataclass
10-
from typing import TYPE_CHECKING, Any, Iterable, List, Mapping, MutableMapping, Optional, Union, TypeVar
10+
from typing import (
11+
TYPE_CHECKING,
12+
Any,
13+
Iterable,
14+
List,
15+
Mapping,
16+
MutableMapping,
17+
Optional,
18+
Union,
19+
TypeVar,
20+
)
1121

1222
import dpath
1323
import requests
@@ -27,7 +37,6 @@
2737

2838

2939
def iterate_with_last_flag(generator: Iterable[Partition]) -> Iterable[tuple[Partition, bool]]:
30-
3140
iterator = iter(generator)
3241

3342
try:
@@ -191,8 +200,12 @@ def stream_slices(self) -> Iterable[StreamSlice]:
191200
for field_path in parent_stream_config.extra_fields
192201
]
193202

194-
for partition, is_last_slice in iterate_with_last_flag(parent_stream.generate_partitions()):
195-
for parent_record, is_last_record_in_slice in iterate_with_last_flag(partition.read()):
203+
for partition, is_last_slice in iterate_with_last_flag(
204+
parent_stream.generate_partitions()
205+
):
206+
for parent_record, is_last_record_in_slice in iterate_with_last_flag(
207+
partition.read()
208+
):
196209
parent_stream.cursor.observe(parent_record)
197210
parent_partition = (
198211
parent_record.associated_slice.partition
@@ -211,7 +224,9 @@ def stream_slices(self) -> Iterable[StreamSlice]:
211224
continue
212225

213226
# Add extra fields
214-
extracted_extra_fields = self._extract_extra_fields(record_data, extra_fields)
227+
extracted_extra_fields = self._extract_extra_fields(
228+
record_data, extra_fields
229+
)
215230

216231
if parent_stream_config.lazy_read_pointer:
217232
extracted_extra_fields = {
@@ -421,7 +436,9 @@ def get_stream_state(self) -> Optional[Mapping[str, StreamState]]:
421436
parent_state = {}
422437
for parent_config in self.parent_stream_configs:
423438
if parent_config.incremental_dependency:
424-
parent_state[parent_config.stream.name] = copy.deepcopy(parent_config.stream.cursor.state)
439+
parent_state[parent_config.stream.name] = copy.deepcopy(
440+
parent_config.stream.cursor.state
441+
)
425442
return parent_state
426443

427444
@property

unit_tests/connector_builder/test_connector_builder_handler.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,7 +1140,9 @@ def test_read_source(mock_http_stream):
11401140
for s in streams:
11411141
retriever = get_retriever(s)
11421142
assert isinstance(retriever, SimpleRetriever)
1143-
assert isinstance(retriever.stream_slicer, StreamSlicerTestReadDecorator)
1143+
assert isinstance(
1144+
s._stream_partition_generator._stream_slicer, StreamSlicerTestReadDecorator
1145+
)
11441146

11451147

11461148
@patch.object(
@@ -1188,7 +1190,9 @@ def test_read_source_single_page_single_slice(mock_http_stream):
11881190
for s in streams:
11891191
retriever = get_retriever(s)
11901192
assert isinstance(retriever, SimpleRetriever)
1191-
assert isinstance(retriever.stream_slicer, StreamSlicerTestReadDecorator)
1193+
assert isinstance(
1194+
s._stream_partition_generator._stream_slicer, StreamSlicerTestReadDecorator
1195+
)
11921196

11931197

11941198
@pytest.mark.parametrize(

unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1185,8 +1185,8 @@ def run_incremental_parent_state_test(
11851185
),
11861186
# FIXME this is an interesting case. The previous solution would not update the parent state until `ensure_at_least_one_state_emitted` but the concurrent cursor does just before which is probably fine too
11871187
(
1188-
f"https://api.example.com/community/posts?per_page=100&start_time={POST_1_UPDATED_AT}",
1189-
{"posts": [{"id": 1, "updated_at": POST_1_UPDATED_AT}]},
1188+
f"https://api.example.com/community/posts?per_page=100&start_time={POST_1_UPDATED_AT}",
1189+
{"posts": [{"id": 1, "updated_at": POST_1_UPDATED_AT}]},
11901190
),
11911191
# Fetch the first page of comments for post 1
11921192
(
@@ -1483,8 +1483,8 @@ def run_incremental_parent_state_test(
14831483
),
14841484
# FIXME this is an interesting case. The previous solution would not update the parent state until `ensure_at_least_one_state_emitted` but the concurrent cursor does just before which is probably fine too
14851485
(
1486-
f"https://api.example.com/community/posts?per_page=100&start_time={POST_1_UPDATED_AT}",
1487-
{"posts": [{"id": 1, "updated_at": POST_1_UPDATED_AT}]},
1486+
f"https://api.example.com/community/posts?per_page=100&start_time={POST_1_UPDATED_AT}",
1487+
{"posts": [{"id": 1, "updated_at": POST_1_UPDATED_AT}]},
14881488
),
14891489
# Fetch the first page of comments for post 1
14901490
(
@@ -1629,8 +1629,8 @@ def run_incremental_parent_state_test(
16291629
),
16301630
# FIXME this is an interesting case. The previous solution would not update the parent state until `ensure_at_least_one_state_emitted` but the concurrent cursor does just before which is probably fine too
16311631
(
1632-
f"https://api.example.com/community/posts?per_page=100&start_time={POST_1_UPDATED_AT}",
1633-
{"posts": [{"id": 1, "updated_at": POST_1_UPDATED_AT}]},
1632+
f"https://api.example.com/community/posts?per_page=100&start_time={POST_1_UPDATED_AT}",
1633+
{"posts": [{"id": 1, "updated_at": POST_1_UPDATED_AT}]},
16341634
),
16351635
# Fetch the first page of comments for post 1
16361636
(
@@ -2130,7 +2130,9 @@ def test_incremental_parent_state_migration(
21302130
"states": [
21312131
{
21322132
"partition": {"id": 1, "parent_slice": {}},
2133-
"cursor": {"updated_at": START_DATE}, # FIXME this happens because the concurrent framework gets the start date as the max between the state value and the start value. In this case, the start value is higher
2133+
"cursor": {
2134+
"updated_at": START_DATE
2135+
}, # FIXME this happens because the concurrent framework gets the start date as the max between the state value and the start value. In this case, the start value is higher
21342136
}
21352137
],
21362138
"lookback_window": 0, # FIXME the concurrent framework sets the lookback window to 0 as opposed to the declarative framework which would set not define it

0 commit comments

Comments
 (0)