Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1976,7 +1976,10 @@ def create_default_stream(
primary_key = model.primary_key.__root__ if model.primary_key else None

partition_router = self._build_stream_slicer_from_partition_router(
model.retriever, config, stream_name=model.name
model.retriever,
config,
stream_name=model.name,
**kwargs,
)
concurrent_cursor = self._build_concurrent_cursor(model, partition_router, config)
if model.incremental_sync and isinstance(model.incremental_sync, DatetimeBasedCursorModel):
Expand Down Expand Up @@ -2155,10 +2158,11 @@ def _build_stream_slicer_from_partition_router(
],
config: Config,
stream_name: Optional[str] = None,
**kwargs: Any,
) -> PartitionRouter:
if (
hasattr(model, "partition_router")
and isinstance(model, SimpleRetrieverModel | AsyncRetrieverModel)
and isinstance(model, (SimpleRetrieverModel, AsyncRetrieverModel, CustomRetrieverModel))
and model.partition_router
):
stream_slicer_model = model.partition_router
Expand All @@ -2172,6 +2176,23 @@ def _build_stream_slicer_from_partition_router(
],
parameters={},
)
elif isinstance(stream_slicer_model, dict):
# partition router comes from CustomRetrieverModel therefore has not been parsed as a model
params = stream_slicer_model.get("$parameters")
if not isinstance(params, dict):
params = {}
stream_slicer_model["$parameters"] = params

if stream_name is not None:
params["stream_name"] = stream_name

return self._create_nested_component( # type: ignore[no-any-return] # There is no guarantee that this will return a stream slicer. If not, we expect an AttributeError during the call to `stream_slices`
model,
"partition_router",
stream_slicer_model,
config,
**kwargs,
)
else:
return self._create_component_from_model( # type: ignore[no-any-return] # Will be created PartitionRouter as stream_slicer_model is model.partition_router
model=stream_slicer_model, config=config, stream_name=stream_name or ""
Expand Down Expand Up @@ -2886,7 +2907,7 @@ def create_page_increment(
)

def create_parent_stream_config(
self, model: ParentStreamConfigModel, config: Config, stream_name: str, **kwargs: Any
self, model: ParentStreamConfigModel, config: Config, *, stream_name: str, **kwargs: Any
) -> ParentStreamConfig:
declarative_stream = self._create_component_from_model(
model.stream,
Expand Down Expand Up @@ -3693,14 +3714,19 @@ def create_spec(self, model: SpecModel, config: Config, **kwargs: Any) -> Spec:
)

def create_substream_partition_router(
self, model: SubstreamPartitionRouterModel, config: Config, **kwargs: Any
self,
model: SubstreamPartitionRouterModel,
config: Config,
*,
stream_name: str,
**kwargs: Any,
) -> SubstreamPartitionRouter:
parent_stream_configs = []
if model.parent_stream_configs:
parent_stream_configs.extend(
[
self.create_parent_stream_config_with_substream_wrapper(
model=parent_stream_config, config=config, **kwargs
model=parent_stream_config, config=config, stream_name=stream_name, **kwargs
)
for parent_stream_config in model.parent_stream_configs
]
Expand All @@ -3720,7 +3746,7 @@ def create_parent_stream_config_with_substream_wrapper(

# This flag will be used exclusively for StateDelegatingStream when a parent stream is created
has_parent_state = bool(
self._connector_state_manager.get_stream_state(kwargs.get("stream_name", ""), None)
self._connector_state_manager.get_stream_state(stream_name, None)
if model.incremental_dependency
else False
)
Expand Down Expand Up @@ -4113,11 +4139,17 @@ def set_api_budget(self, component_definition: ComponentDefinition, config: Conf
)

def create_grouping_partition_router(
self, model: GroupingPartitionRouterModel, config: Config, **kwargs: Any
self,
model: GroupingPartitionRouterModel,
config: Config,
*,
stream_name: str,
**kwargs: Any,
) -> GroupingPartitionRouter:
underlying_router = self._create_component_from_model(
model=model.underlying_partition_router,
config=config,
stream_name=stream_name,
**kwargs,
)
if model.group_size < 1:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -992,6 +992,202 @@ def test_stream_with_custom_retriever_and_transformations():
assert get_retriever(stream).record_selector.transformations


def test_stream_with_custom_retriever_and_partition_router():
content = """
a_stream:
type: DeclarativeStream
primary_key: id
schema_loader:
type: InlineSchemaLoader
schema:
$schema: "http://json-schema.org/draft-07/schema"
type: object
properties:
id:
type: string
retriever:
type: CustomRetriever
class_name: unit_tests.sources.declarative.parsers.testing_components.TestingCustomRetriever
record_selector:
type: RecordSelector
extractor:
field_path: []
requester:
type: HttpRequester
url_base: "https://api.sendgrid.com/v3/"
http_method: "GET"
partition_router:
type: SubstreamPartitionRouter
parent_stream_configs:
- parent_key: id
partition_field: id
stream:
type: DeclarativeStream
primary_key: id
schema_loader:
type: InlineSchemaLoader
schema:
$schema: "http://json-schema.org/draft-07/schema"
type: object
properties:
id:
type: string
retriever:
type: SimpleRetriever
requester:
type: HttpRequester
url_base: "https://api.sendgrid.com/v3/parent"
http_method: "GET"
record_selector:
type: RecordSelector
extractor:
field_path: []
$parameters:
name: a_stream
"""

parsed_manifest = YamlDeclarativeSource._parse(content)
resolved_manifest = resolver.preprocess_manifest(parsed_manifest)
stream_manifest = transformer.propagate_types_and_parameters(
"", resolved_manifest["a_stream"], {}
)

stream = factory.create_component(
model_type=DeclarativeStreamModel, component_definition=stream_manifest, config=input_config
)

assert isinstance(stream, DefaultStream)
assert isinstance(stream._stream_partition_generator._stream_slicer, SubstreamPartitionRouter)


def test_stream_with_custom_retriever_with_partition_router_field_that_is_not_a_partition_router():
"""
This test documents the behavior where if a custom retriever has a field named partition_router, it will assume
it can generate stream_slices with this parameter. In this test, the partition_router is a RecordSelector that can't
generate stream_slices so there will be an AttributeError.
"""
content = """
a_stream:
type: DeclarativeStream
primary_key: id
schema_loader:
type: InlineSchemaLoader
schema:
$schema: "http://json-schema.org/draft-07/schema"
type: object
properties:
id:
type: string
retriever:
type: CustomRetriever
class_name: unit_tests.sources.declarative.parsers.testing_components.TestingCustomRetriever
record_selector:
type: RecordSelector
extractor:
field_path: []
requester:
type: HttpRequester
url_base: "https://api.sendgrid.com/v3/"
http_method: "GET"
partition_router:
type: RecordSelector
extractor:
field_path: []
$parameters:
name: a_stream
"""

parsed_manifest = YamlDeclarativeSource._parse(content)
resolved_manifest = resolver.preprocess_manifest(parsed_manifest)
stream_manifest = transformer.propagate_types_and_parameters(
"", resolved_manifest["a_stream"], {}
)

stream = factory.create_component(
model_type=DeclarativeStreamModel, component_definition=stream_manifest, config=input_config
)

assert isinstance(stream, DefaultStream)
with pytest.raises(AttributeError) as e:
list(stream.generate_partitions())
assert e.value.args[0] == "'RecordSelector' object has no attribute 'stream_slices'"


def test_incremental_stream_with_custom_retriever_and_partition_router():
content = """
a_stream:
type: DeclarativeStream
primary_key: id
schema_loader:
type: InlineSchemaLoader
schema:
$schema: "http://json-schema.org/draft-07/schema"
type: object
properties:
id:
type: string
incremental_sync:
type: DatetimeBasedCursor
datetime_format: "%Y-%m-%dT%H:%M:%S.%f%z"
start_datetime: "{{ config['start_time'] }}"
cursor_field: "created"
retriever:
type: CustomRetriever
class_name: unit_tests.sources.declarative.parsers.testing_components.TestingCustomRetriever
record_selector:
type: RecordSelector
extractor:
field_path: []
requester:
type: HttpRequester
url_base: "https://api.sendgrid.com/v3/"
http_method: "GET"
partition_router:
type: SubstreamPartitionRouter
parent_stream_configs:
- parent_key: id
partition_field: id
stream:
type: DeclarativeStream
primary_key: id
schema_loader:
type: InlineSchemaLoader
schema:
$schema: "http://json-schema.org/draft-07/schema"
type: object
properties:
id:
type: string
retriever:
type: SimpleRetriever
requester:
type: HttpRequester
url_base: "https://api.sendgrid.com/v3/parent"
http_method: "GET"
record_selector:
type: RecordSelector
extractor:
field_path: []
$parameters:
name: a_stream
"""

parsed_manifest = YamlDeclarativeSource._parse(content)
resolved_manifest = resolver.preprocess_manifest(parsed_manifest)
stream_manifest = transformer.propagate_types_and_parameters(
"", resolved_manifest["a_stream"], {}
)

stream = factory.create_component(
model_type=DeclarativeStreamModel, component_definition=stream_manifest, config=input_config
)

assert isinstance(stream, DefaultStream)
assert isinstance(
stream._stream_partition_generator._stream_slicer, ConcurrentPerPartitionCursor
)


@pytest.mark.parametrize(
"use_legacy_state",
[
Expand Down
Loading