Skip to content

Commit 1f9522a

Browse files
committed
file-api: move catalog passing from streams() to ModelToComponentFactory.__init__() so we don't mess with the interface signature
1 parent 35b293f commit 1f9522a

File tree

3 files changed

+37
-47
lines changed

3 files changed

+37
-47
lines changed

airbyte_cdk/sources/declarative/concurrent_declarative_source.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def __init__(
8888
emit_connector_builder_messages=emit_connector_builder_messages,
8989
disable_resumable_full_refresh=True,
9090
connector_state_manager=self._connector_state_manager,
91+
catalog=catalog,
9192
)
9293

9394
super().__init__(
@@ -139,7 +140,7 @@ def read(
139140
catalog: ConfiguredAirbyteCatalog,
140141
state: Optional[List[AirbyteStateMessage]] = None,
141142
) -> Iterator[AirbyteMessage]:
142-
concurrent_streams, _ = self._group_streams(config=config, catalog=catalog)
143+
concurrent_streams, _ = self._group_streams(config=config)
143144

144145
# ConcurrentReadProcessor pops streams that are finished being read so before syncing, the names of
145146
# the concurrent streams must be saved so that they can be removed from the catalog before starting
@@ -180,9 +181,7 @@ def discover(self, logger: logging.Logger, config: Mapping[str, Any]) -> Airbyte
180181
]
181182
)
182183

183-
def streams(
184-
self, config: Mapping[str, Any], catalog: ConfiguredAirbyteCatalog | None = None
185-
) -> List[Stream]:
184+
def streams(self, config: Mapping[str, Any]) -> List[Stream]:
186185
"""
187186
The `streams` method is used as part of the AbstractSource in the following cases:
188187
* ConcurrentDeclarativeSource.check -> ManifestDeclarativeSource.check -> AbstractSource.check -> DeclarativeSource.check_connection -> CheckStream.check_connection -> streams
@@ -191,10 +190,10 @@ def streams(
191190
192191
In both cases, we assume that calling the DeclarativeStream is perfectly fine, as the result is the same regardless of whether it is a DeclarativeStream or a DefaultStream (concurrent). This should simply be removed once we have moved away from the code paths mentioned above.
193192
"""
194-
return super().streams(config, catalog=catalog)
193+
return super().streams(config)
195194

196195
def _group_streams(
197-
self, config: Mapping[str, Any], catalog: ConfiguredAirbyteCatalog | None = None
196+
self, config: Mapping[str, Any]
198197
) -> Tuple[List[AbstractStream], List[Stream]]:
199198
concurrent_streams: List[AbstractStream] = []
200199
synchronous_streams: List[Stream] = []
@@ -207,7 +206,7 @@ def _group_streams(
207206

208207
name_to_stream_mapping = {stream["name"]: stream for stream in streams}
209208

210-
for declarative_stream in self.streams(config=config, catalog=catalog):
209+
for declarative_stream in self.streams(config=config):
211210
# Some low-code sources use a combination of DeclarativeStream and regular Python streams. We can't inspect
212211
# these legacy Python streams the way we do low-code streams to determine if they are concurrent compatible,
213212
# so we need to treat them as synchronous

airbyte_cdk/sources/declarative/manifest_declarative_source.py

Lines changed: 3 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
AirbyteMessage,
2121
AirbyteStateMessage,
2222
ConfiguredAirbyteCatalog,
23-
ConfiguredAirbyteStream,
2423
ConnectorSpecification,
2524
FailureType,
2625
)
@@ -93,6 +92,7 @@ def __init__(
9392
emit_connector_builder_messages: bool = False,
9493
component_factory: Optional[ModelToComponentFactory] = None,
9594
normalize_manifest: Optional[bool] = False,
95+
catalog: Optional[ConfiguredAirbyteCatalog] = None,
9696
) -> None:
9797
"""
9898
Args:
@@ -119,6 +119,7 @@ def __init__(
119119
else ModelToComponentFactory(
120120
emit_connector_builder_messages,
121121
max_concurrent_async_job_count=source_config.get("max_concurrent_async_job_count"),
122+
catalog=catalog,
122123
)
123124
)
124125
self._message_repository = self._constructor.get_message_repository()
@@ -230,9 +231,7 @@ def connection_checker(self) -> ConnectionChecker:
230231
f"Expected to generate a ConnectionChecker component, but received {check_stream.__class__}"
231232
)
232233

233-
def streams(
234-
self, config: Mapping[str, Any], catalog: Optional[ConfiguredAirbyteCatalog] = None
235-
) -> List[Stream]:
234+
def streams(self, config: Mapping[str, Any]) -> List[Stream]:
236235
self._emit_manifest_debug_message(
237236
extra_args={
238237
"source_name": self.name,
@@ -248,7 +247,6 @@ def streams(
248247
if api_budget_model:
249248
self._constructor.set_api_budget(api_budget_model, config)
250249

251-
catalog_with_streams_name = self._catalog_with_streams_name(catalog)
252250
source_streams = [
253251
self._constructor.create_component(
254252
(
@@ -259,45 +257,12 @@ def streams(
259257
stream_config,
260258
config,
261259
emit_connector_builder_messages=self._emit_connector_builder_messages,
262-
include_files=self._get_include_files(
263-
stream_config=stream_config, catalog_with_streams_name=catalog_with_streams_name
264-
),
265260
)
266261
for stream_config in self._initialize_cache_for_parent_streams(deepcopy(stream_configs))
267262
]
268263

269264
return source_streams
270265

271-
@staticmethod
272-
def _get_include_files(
273-
stream_config: Dict[str, Any],
274-
catalog_with_streams_name: Mapping[str, ConfiguredAirbyteStream] | None,
275-
) -> bool:
276-
"""
277-
Returns the include_files for the stream if it exists in the catalog.
278-
"""
279-
if catalog_with_streams_name:
280-
stream_name = stream_config.get("name")
281-
configured_catalog_stream = (
282-
catalog_with_streams_name.get(stream_name) if stream_name else None
283-
)
284-
return bool(configured_catalog_stream and configured_catalog_stream.include_files)
285-
return False
286-
287-
@staticmethod
288-
def _catalog_with_streams_name(
289-
catalog: ConfiguredAirbyteCatalog | None,
290-
) -> Mapping[str, ConfiguredAirbyteStream] | None:
291-
"""
292-
Returns a dict mapping stream names to their corresponding ConfiguredAirbyteStream objects.
293-
"""
294-
if catalog:
295-
return {
296-
configured_stream.stream.name: configured_stream
297-
for configured_stream in catalog.streams
298-
}
299-
return None
300-
301266
@staticmethod
302267
def _initialize_cache_for_parent_streams(
303268
stream_configs: List[Dict[str, Any]],

airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from isodate import parse_duration
2828
from pydantic.v1 import BaseModel
2929

30-
from airbyte_cdk.models import FailureType, Level
30+
from airbyte_cdk.models import ConfiguredAirbyteCatalog, ConfiguredAirbyteStream, FailureType, Level
3131
from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
3232
from airbyte_cdk.sources.declarative.async_job.job_orchestrator import AsyncJobOrchestrator
3333
from airbyte_cdk.sources.declarative.async_job.job_tracker import JobTracker
@@ -576,6 +576,7 @@ def __init__(
576576
disable_retries: bool = False,
577577
disable_cache: bool = False,
578578
disable_resumable_full_refresh: bool = False,
579+
catalog: ConfiguredAirbyteCatalog = ConfiguredAirbyteCatalog(streams=[]),
579580
message_repository: Optional[MessageRepository] = None,
580581
connector_state_manager: Optional[ConnectorStateManager] = None,
581582
max_concurrent_async_job_count: Optional[int] = None,
@@ -593,6 +594,7 @@ def __init__(
593594
self._connector_state_manager = connector_state_manager or ConnectorStateManager()
594595
self._api_budget: Optional[Union[APIBudget, HttpAPIBudget]] = None
595596
self._job_tracker: JobTracker = JobTracker(max_concurrent_async_job_count or 1)
597+
self._catalog_with_streams_name = self._get_catalog_with_streams_name(catalog)
596598

597599
def _init_mappings(self) -> None:
598600
self.PYDANTIC_MODEL_TO_CONSTRUCTOR: Mapping[Type[BaseModel], Callable[..., Any]] = {
@@ -1852,7 +1854,7 @@ def create_declarative_stream(
18521854
)
18531855
file_uploader = None
18541856
if model.file_uploader:
1855-
include_files = kwargs.pop("include_files", False)
1857+
include_files = self._get_include_files(model)
18561858
file_uploader = self._create_component_from_model(
18571859
model=model.file_uploader, config=config, include_files=include_files
18581860
)
@@ -3711,3 +3713,27 @@ def create_grouping_partition_router(
37113713
deduplicate=model.deduplicate if model.deduplicate is not None else True,
37123714
config=config,
37133715
)
3716+
3717+
@staticmethod
3718+
def _get_catalog_with_streams_name(
3719+
catalog: ConfiguredAirbyteCatalog,
3720+
) -> Mapping[str, ConfiguredAirbyteStream]:
3721+
"""
3722+
Returns a dict mapping stream names to their corresponding ConfiguredAirbyteStream objects.
3723+
"""
3724+
return {
3725+
configured_stream.stream.name: configured_stream
3726+
for configured_stream in catalog.streams
3727+
}
3728+
3729+
def _get_include_files(
3730+
self,
3731+
stream_model: DeclarativeStreamModel,
3732+
) -> bool:
3733+
"""
3734+
Returns True if the stream exists in the configured catalog and has include_files enabled; otherwise returns False.
3735+
"""
3736+
if stream_model.name and self._catalog_with_streams_name:
3737+
configured_catalog_stream = self._catalog_with_streams_name.get(stream_model.name)
3738+
return bool(configured_catalog_stream and configured_catalog_stream.include_files)
3739+
return False

0 commit comments

Comments
 (0)