Skip to content

Commit 25ca5b8

Browse files
author
maxime.c
committed
allow for specific parameters to be passed to custom components
1 parent c004637 commit 25ca5b8

File tree

3 files changed

+75
-12
lines changed

3 files changed

+75
-12
lines changed

airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -752,7 +752,7 @@ def _init_mappings(self) -> None:
752752
OAuthAuthenticatorModel: self.create_oauth_authenticator,
753753
OffsetIncrementModel: self.create_offset_increment,
754754
PageIncrementModel: self.create_page_increment,
755-
ParentStreamConfigModel: self.create_parent_stream_config,
755+
ParentStreamConfigModel: self._create_message_repository_substream_wrapper,
756756
PredicateValidatorModel: self.create_predicate_validator,
757757
PropertiesFromEndpointModel: self.create_properties_from_endpoint,
758758
PropertyChunkingModel: self.create_property_chunking,
@@ -1748,7 +1748,7 @@ def create_custom_component(self, model: Any, config: Config, **kwargs: Any) ->
17481748

17491749
if self._is_component(model_value):
17501750
model_args[model_field] = self._create_nested_component(
1751-
model, model_field, model_value, config
1751+
model, model_field, model_value, config, **kwargs,
17521752
)
17531753
elif isinstance(model_value, list):
17541754
vals = []
@@ -1760,7 +1760,7 @@ def create_custom_component(self, model: Any, config: Config, **kwargs: Any) ->
17601760
if derived_type:
17611761
v["type"] = derived_type
17621762
if self._is_component(v):
1763-
vals.append(self._create_nested_component(model, model_field, v, config))
1763+
vals.append(self._create_nested_component(model, model_field, v, config, **kwargs,))
17641764
else:
17651765
vals.append(v)
17661766
model_args[model_field] = vals
@@ -1850,7 +1850,7 @@ def _extract_missing_parameters(error: TypeError) -> List[str]:
18501850
return []
18511851

18521852
def _create_nested_component(
1853-
self, model: Any, model_field: str, model_value: Any, config: Config
1853+
self, model: Any, model_field: str, model_value: Any, config: Config, **kwargs: Any
18541854
) -> Any:
18551855
type_name = model_value.get("type", None)
18561856
if not type_name:
@@ -1875,8 +1875,11 @@ def _create_nested_component(
18751875
for kwarg in constructor_kwargs
18761876
if kwarg in model_parameters
18771877
}
1878+
matching_kwargs = {
1879+
kwarg: kwargs[kwarg] for kwarg in constructor_kwargs if kwarg in kwargs
1880+
}
18781881
return self._create_component_from_model(
1879-
model=parsed_model, config=config, **matching_parameters
1882+
model=parsed_model, config=config, **(matching_parameters | matching_kwargs)
18801883
)
18811884
except TypeError as error:
18821885
missing_parameters = self._extract_missing_parameters(error)
@@ -2871,7 +2874,7 @@ def create_page_increment(
28712874
)
28722875

28732876
def create_parent_stream_config(
2874-
self, model: ParentStreamConfigModel, config: Config, **kwargs: Any
2877+
self, model: ParentStreamConfigModel, config: Config, stream_name: str, **kwargs: Any
28752878
) -> ParentStreamConfig:
28762879
declarative_stream = self._create_component_from_model(
28772880
model.stream,
@@ -3695,11 +3698,11 @@ def create_substream_partition_router(
36953698
)
36963699

36973700
def _create_message_repository_substream_wrapper(
3698-
self, model: ParentStreamConfigModel, config: Config, **kwargs: Any
3701+
self, model: ParentStreamConfigModel, config: Config, *, stream_name: str, **kwargs: Any
36993702
) -> Any:
37003703
# getting the parent state
37013704
child_state = self._connector_state_manager.get_stream_state(
3702-
kwargs["stream_name"], None
3705+
stream_name, None
37033706
)
37043707

37053708
# This flag will be used exclusively for StateDelegatingStream when a parent stream is created
@@ -3731,8 +3734,8 @@ def _create_message_repository_substream_wrapper(
37313734
),
37323735
)
37333736

3734-
return substream_factory._create_component_from_model(
3735-
model=model, config=config, has_parent_state=has_parent_state, **kwargs
3737+
return substream_factory.create_parent_stream_config(
3738+
model=model, config=config, stream_name=stream_name, **kwargs
37363739
)
37373740

37383741
def _instantiate_parent_stream_state_manager(

unit_tests/sources/declarative/parsers/test_model_to_component_factory.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@
163163
from airbyte_cdk.sources.declarative.transformations import AddFields, RemoveFields
164164
from airbyte_cdk.sources.declarative.transformations.add_fields import AddedFieldDefinition
165165
from airbyte_cdk.sources.declarative.yaml_declarative_source import YamlDeclarativeSource
166+
from airbyte_cdk.sources.message.repository import StateFilteringMessageRepository
166167
from airbyte_cdk.sources.streams.call_rate import MovingWindowCallRatePolicy
167168
from airbyte_cdk.sources.streams.concurrent.clamping import (
168169
ClampingEndProvider,
@@ -944,6 +945,58 @@ def test_stream_with_incremental_and_retriever_with_partition_router():
944945
assert list_stream_slicer._cursor_field.string == "a_key"
945946

946947

948+
def test_stream_with_custom_retriever_and_transformations():
949+
content = """
950+
a_stream:
951+
type: DeclarativeStream
952+
primary_key: id
953+
schema_loader:
954+
type: InlineSchemaLoader
955+
schema:
956+
$schema: "http://json-schema.org/draft-07/schema"
957+
type: object
958+
properties:
959+
id:
960+
type: string
961+
retriever:
962+
type: CustomRetriever
963+
class_name: unit_tests.sources.declarative.parsers.testing_components.TestingCustomRetriever
964+
name: "{{ parameters['name'] }}"
965+
decoder:
966+
type: JsonDecoder
967+
requester:
968+
type: HttpRequester
969+
name: "{{ parameters['name'] }}"
970+
url_base: "https://api.sendgrid.com/v3/"
971+
http_method: "GET"
972+
record_selector:
973+
type: RecordSelector
974+
extractor:
975+
type: DpathExtractor
976+
field_path: ["records"]
977+
transformations:
978+
- type: AddFields
979+
fields:
980+
- path: ["extra"]
981+
value: "{{ response.to_add }}"
982+
$parameters:
983+
name: a_stream
984+
"""
985+
986+
parsed_manifest = YamlDeclarativeSource._parse(content)
987+
resolved_manifest = resolver.preprocess_manifest(parsed_manifest)
988+
stream_manifest = transformer.propagate_types_and_parameters(
989+
"", resolved_manifest["a_stream"], {}
990+
)
991+
992+
stream = factory.create_component(
993+
model_type=DeclarativeStreamModel, component_definition=stream_manifest, config=input_config
994+
)
995+
996+
assert isinstance(stream, DefaultStream)
997+
assert get_retriever(stream).record_selector.transformations
998+
999+
9471000
@pytest.mark.parametrize(
9481001
"use_legacy_state",
9491002
[
@@ -2053,11 +2106,12 @@ def test_custom_components_do_not_contain_extra_fields():
20532106
}
20542107

20552108
custom_substream_partition_router = factory.create_component(
2056-
CustomPartitionRouterModel, custom_substream_partition_router_manifest, input_config
2109+
CustomPartitionRouterModel, custom_substream_partition_router_manifest, input_config, stream_name="child_stream_name",
20572110
)
20582111
assert isinstance(custom_substream_partition_router, TestingCustomSubstreamPartitionRouter)
20592112

20602113
assert len(custom_substream_partition_router.parent_stream_configs) == 1
2114+
assert isinstance(custom_substream_partition_router.parent_stream_configs[0].stream.cursor._message_repository, StateFilteringMessageRepository)
20612115
assert custom_substream_partition_router.parent_stream_configs[0].parent_key.eval({}) == "id"
20622116
assert (
20632117
custom_substream_partition_router.parent_stream_configs[0].partition_field.eval({})
@@ -2120,7 +2174,7 @@ def test_parse_custom_component_fields_if_subcomponent():
21202174
}
21212175

21222176
custom_substream_partition_router = factory.create_component(
2123-
CustomPartitionRouterModel, custom_substream_partition_router_manifest, input_config
2177+
CustomPartitionRouterModel, custom_substream_partition_router_manifest, input_config, stream_name="child_stream_name"
21242178
)
21252179
assert isinstance(custom_substream_partition_router, TestingCustomSubstreamPartitionRouter)
21262180
assert custom_substream_partition_router.custom_field == "here"

unit_tests/sources/declarative/parsers/testing_components.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
DefaultPaginator,
1414
PaginationStrategy,
1515
)
16+
from airbyte_cdk.sources.declarative.retrievers import SimpleRetriever
1617

1718

1819
@dataclass
@@ -43,3 +44,8 @@ class TestingCustomSubstreamPartitionRouter(SubstreamPartitionRouter):
4344

4445
custom_field: str
4546
custom_pagination_strategy: PaginationStrategy
47+
48+
49+
@dataclass
50+
class TestingCustomRetriever(SimpleRetriever):
51+
pass

0 commit comments

Comments
 (0)