
Commit 90eeaa6

Author: maxime.c
Merge branch 'main' into maxi297/incremental_without_partition_router_as_defaultstream
2 parents e996805 + 02246dc

25 files changed: +359 −523 lines

airbyte_cdk/connector_builder/connector_builder_handler.py

Lines changed: 36 additions & 32 deletions
@@ -3,8 +3,8 @@
 #


-from dataclasses import asdict
-from typing import Any, Dict, List, Mapping, Optional
+from dataclasses import asdict, dataclass, field
+from typing import Any, ClassVar, Dict, List, Mapping

 from airbyte_cdk.connector_builder.test_reader import TestReader
 from airbyte_cdk.models import (
@@ -15,32 +15,45 @@
     Type,
 )
 from airbyte_cdk.models import Type as MessageType
-from airbyte_cdk.sources.declarative.concurrent_declarative_source import (
-    ConcurrentDeclarativeSource,
-    TestLimits,
-)
 from airbyte_cdk.sources.declarative.declarative_source import DeclarativeSource
 from airbyte_cdk.sources.declarative.manifest_declarative_source import ManifestDeclarativeSource
+from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import (
+    ModelToComponentFactory,
+)
 from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets
 from airbyte_cdk.utils.datetime_helpers import ab_datetime_now
 from airbyte_cdk.utils.traced_exception import AirbyteTracedException

+DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE = 5
+DEFAULT_MAXIMUM_NUMBER_OF_SLICES = 5
+DEFAULT_MAXIMUM_RECORDS = 100
+DEFAULT_MAXIMUM_STREAMS = 100
+
 MAX_PAGES_PER_SLICE_KEY = "max_pages_per_slice"
 MAX_SLICES_KEY = "max_slices"
 MAX_RECORDS_KEY = "max_records"
 MAX_STREAMS_KEY = "max_streams"


+@dataclass
+class TestLimits:
+    __test__: ClassVar[bool] = False  # Tell Pytest this is not a Pytest class, despite its name
+
+    max_records: int = field(default=DEFAULT_MAXIMUM_RECORDS)
+    max_pages_per_slice: int = field(default=DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE)
+    max_slices: int = field(default=DEFAULT_MAXIMUM_NUMBER_OF_SLICES)
+    max_streams: int = field(default=DEFAULT_MAXIMUM_STREAMS)
+
+
 def get_limits(config: Mapping[str, Any]) -> TestLimits:
     command_config = config.get("__test_read_config", {})
-    return TestLimits(
-        max_records=command_config.get(MAX_RECORDS_KEY, TestLimits.DEFAULT_MAX_RECORDS),
-        max_pages_per_slice=command_config.get(
-            MAX_PAGES_PER_SLICE_KEY, TestLimits.DEFAULT_MAX_PAGES_PER_SLICE
-        ),
-        max_slices=command_config.get(MAX_SLICES_KEY, TestLimits.DEFAULT_MAX_SLICES),
-        max_streams=command_config.get(MAX_STREAMS_KEY, TestLimits.DEFAULT_MAX_STREAMS),
+    max_pages_per_slice = (
+        command_config.get(MAX_PAGES_PER_SLICE_KEY) or DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE
     )
+    max_slices = command_config.get(MAX_SLICES_KEY) or DEFAULT_MAXIMUM_NUMBER_OF_SLICES
+    max_records = command_config.get(MAX_RECORDS_KEY) or DEFAULT_MAXIMUM_RECORDS
+    max_streams = command_config.get(MAX_STREAMS_KEY) or DEFAULT_MAXIMUM_STREAMS
+    return TestLimits(max_records, max_pages_per_slice, max_slices, max_streams)


 def should_migrate_manifest(config: Mapping[str, Any]) -> bool:
@@ -62,30 +75,21 @@ def should_normalize_manifest(config: Mapping[str, Any]) -> bool:
     return config.get("__should_normalize", False)


-def create_source(
-    config: Mapping[str, Any],
-    limits: TestLimits,
-    catalog: Optional[ConfiguredAirbyteCatalog],
-    state: Optional[List[AirbyteStateMessage]],
-) -> ConcurrentDeclarativeSource[Optional[List[AirbyteStateMessage]]]:
+def create_source(config: Mapping[str, Any], limits: TestLimits) -> ManifestDeclarativeSource:
     manifest = config["__injected_declarative_manifest"]
-
-    # We enforce a concurrency level of 1 so that the stream is processed on a single thread
-    # to retain ordering for the grouping of the builder message responses.
-    if "concurrency_level" in manifest:
-        manifest["concurrency_level"]["default_concurrency"] = 1
-    else:
-        manifest["concurrency_level"] = {"type": "ConcurrencyLevel", "default_concurrency": 1}
-
-    return ConcurrentDeclarativeSource(
-        catalog=catalog,
+    return ManifestDeclarativeSource(
         config=config,
-        state=state,
-        source_config=manifest,
         emit_connector_builder_messages=True,
+        source_config=manifest,
         migrate_manifest=should_migrate_manifest(config),
         normalize_manifest=should_normalize_manifest(config),
-        limits=limits,
+        component_factory=ModelToComponentFactory(
+            emit_connector_builder_messages=True,
+            limit_pages_fetched_per_slice=limits.max_pages_per_slice,
+            limit_slices_fetched=limits.max_slices,
+            disable_retries=True,
+            disable_cache=True,
+        ),
    )
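
To make the new limit handling concrete, here is a small, self-contained sketch (constants and names copied from the diff above, key strings inlined) showing how get_limits resolves overrides from the "__test_read_config" block, plus one behavioral nuance of switching from .get(key, default) to `or`: a falsy override such as 0 now also falls back to the default.

from dataclasses import dataclass, field
from typing import Any, ClassVar, Mapping

DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE = 5
DEFAULT_MAXIMUM_NUMBER_OF_SLICES = 5
DEFAULT_MAXIMUM_RECORDS = 100
DEFAULT_MAXIMUM_STREAMS = 100


@dataclass
class TestLimits:
    __test__: ClassVar[bool] = False  # Tell Pytest this is not a Pytest class

    max_records: int = field(default=DEFAULT_MAXIMUM_RECORDS)
    max_pages_per_slice: int = field(default=DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE)
    max_slices: int = field(default=DEFAULT_MAXIMUM_NUMBER_OF_SLICES)
    max_streams: int = field(default=DEFAULT_MAXIMUM_STREAMS)


def get_limits(config: Mapping[str, Any]) -> TestLimits:
    command_config = config.get("__test_read_config", {})
    max_records = command_config.get("max_records") or DEFAULT_MAXIMUM_RECORDS
    max_pages = command_config.get("max_pages_per_slice") or DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE
    max_slices = command_config.get("max_slices") or DEFAULT_MAXIMUM_NUMBER_OF_SLICES
    max_streams = command_config.get("max_streams") or DEFAULT_MAXIMUM_STREAMS
    return TestLimits(max_records, max_pages, max_slices, max_streams)


# An explicit override is honored...
print(get_limits({"__test_read_config": {"max_records": 10}}).max_records)  # 10
# ...but a falsy one (0, None) falls back to the default, unlike .get(key, default).
print(get_limits({"__test_read_config": {"max_records": 0}}).max_records)  # 100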

airbyte_cdk/connector_builder/main.py

Lines changed: 3 additions & 3 deletions
@@ -91,12 +91,12 @@ def handle_connector_builder_request(
 def handle_request(args: List[str]) -> str:
     command, config, catalog, state = get_config_and_catalog_from_args(args)
     limits = get_limits(config)
-    source = create_source(config=config, limits=limits, catalog=catalog, state=state)
-    return orjson.dumps(  # type: ignore[no-any-return] # Serializer.dump() always returns AirbyteMessage
+    source = create_source(config, limits)
+    return orjson.dumps(
         AirbyteMessageSerializer.dump(
             handle_connector_builder_request(source, command, config, catalog, state, limits)
         )
-    ).decode()
+    ).decode()  # type: ignore[no-any-return] # Serializer.dump() always returns AirbyteMessage


 if __name__ == "__main__":

airbyte_cdk/connector_builder/test_reader/helpers.py

Lines changed: 2 additions & 24 deletions
@@ -5,7 +5,7 @@
 import json
 from copy import deepcopy
 from json import JSONDecodeError
-from typing import Any, Dict, List, Mapping, Optional, Union
+from typing import Any, Dict, List, Mapping, Optional

 from airbyte_cdk.connector_builder.models import (
     AuxiliaryRequest,
@@ -17,8 +17,6 @@
 from airbyte_cdk.models import (
     AirbyteLogMessage,
     AirbyteMessage,
-    AirbyteStateBlob,
-    AirbyteStateMessage,
     OrchestratorType,
     TraceType,
 )
@@ -468,7 +466,7 @@ def handle_current_slice(
     return StreamReadSlices(
         pages=current_slice_pages,
         slice_descriptor=current_slice_descriptor,
-        state=[convert_state_blob_to_mapping(latest_state_message)] if latest_state_message else [],
+        state=[latest_state_message] if latest_state_message else [],
         auxiliary_requests=auxiliary_requests if auxiliary_requests else [],
     )

@@ -720,23 +718,3 @@ def get_auxiliary_request_type(stream: dict, http: dict) -> str:  # type: ignore
     Determines the type of the auxiliary request based on the stream and HTTP properties.
     """
     return "PARENT_STREAM" if stream.get("is_substream", False) else str(http.get("type", None))
-
-
-def convert_state_blob_to_mapping(
-    state_message: Union[AirbyteStateMessage, Dict[str, Any]],
-) -> Dict[str, Any]:
-    """
-    The AirbyteStreamState stores state as an AirbyteStateBlob which deceivingly is not
-    a dictionary, but rather a list of kwargs fields. This in turn causes it to not be
-    properly turned into a dictionary when translating this back into response output
-    by the connector_builder_handler using asdict()
-    """
-
-    if isinstance(state_message, AirbyteStateMessage) and state_message.stream:
-        state_value = state_message.stream.stream_state
-        if isinstance(state_value, AirbyteStateBlob):
-            state_value_mapping = {k: v for k, v in state_value.__dict__.items()}
-            state_message.stream.stream_state = state_value_mapping  # type: ignore # we intentionally set this as a Dict so that StreamReadSlices is translated properly in the resulting HTTP response
-        return state_message  # type: ignore # See above, but when this is an AirbyteStateMessage we must convert AirbyteStateBlob to a Dict
-    else:
-        return state_message  # type: ignore # This is guaranteed to be a Dict since we check isinstance AirbyteStateMessage above
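
The removed helper existed because AirbyteStateBlob stores state as arbitrary kwargs rather than behaving like a plain mapping, so asdict()-style serialization missed it. A sketch of that pitfall using a hypothetical stand-in class (not the real AirbyteStateBlob):

import json


class StateBlobStandIn:
    """Hypothetical stand-in: stores arbitrary kwargs, like AirbyteStateBlob."""

    def __init__(self, **kwargs: object) -> None:
        self.__dict__.update(kwargs)


blob = StateBlobStandIn(cursor="2024-01-01")

try:
    json.dumps(blob)  # the blob is not a dict, so plain serialization fails
except TypeError as error:
    print(error)

# The removed helper's workaround: copy the kwargs into a real dict first.
print(json.dumps(dict(blob.__dict__)))  # {"cursor": "2024-01-01"}

With this commit, the message grouper already tracks latest_state_message as a plain Dict[str, Any] (see message_grouper.py below), so handle_current_slice can pass it through unconverted.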

airbyte_cdk/connector_builder/test_reader/message_grouper.py

Lines changed: 1 addition & 1 deletion
@@ -95,7 +95,7 @@ def get_message_groups(
     latest_state_message: Optional[Dict[str, Any]] = None
     slice_auxiliary_requests: List[AuxiliaryRequest] = []

-    while message := next(messages, None):
+    while records_count < limit and (message := next(messages, None)):
         json_message = airbyte_message_to_json(message)

         if is_page_http_request_for_different_stream(json_message, stream_name):
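
A minimal sketch of what the new guard changes (simplified types; the real loop counts records elsewhere in its body): because records_count < limit is checked before the walrus assignment, the iterator is no longer advanced once the record budget is spent.

from typing import Iterator, List


def consume(messages: Iterator[int], limit: int) -> List[int]:
    records_count = 0
    consumed: List[int] = []
    # `is not None` is added here because ints can be falsy; AirbyteMessages are not.
    while records_count < limit and (message := next(messages, None)) is not None:
        consumed.append(message)
        records_count += 1
    return consumed


stream = iter(range(10))
print(consume(stream, limit=3))  # [0, 1, 2]
print(next(stream))  # 3 — the message after the limit was never pulled from the iterator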

airbyte_cdk/entrypoint.py

Lines changed: 2 additions & 2 deletions
@@ -22,7 +22,7 @@

 from airbyte_cdk.connector import TConfig
 from airbyte_cdk.exception_handler import init_uncaught_exception_handler
-from airbyte_cdk.logger import PRINT_BUFFER, init_logger
+from airbyte_cdk.logger import PRINT_BUFFER, init_logger, is_platform_debug_log_enabled
 from airbyte_cdk.models import (
     AirbyteConnectionStatus,
     AirbyteMessage,
@@ -158,7 +158,7 @@ def run(self, parsed_args: argparse.Namespace) -> Iterable[str]:
         if not cmd:
             raise Exception("No command passed")

-        if hasattr(parsed_args, "debug") and parsed_args.debug:
+        if (hasattr(parsed_args, "debug") and parsed_args.debug) or is_platform_debug_log_enabled():
             self.logger.setLevel(logging.DEBUG)
             logger.setLevel(logging.DEBUG)
             self.logger.debug("Debug logs enabled")
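
A condensed stand-in (not the Entrypoint code itself) for the updated gate: debug logging can now be switched on by the platform through the LOG_LEVEL environment variable, without the --debug CLI flag.

import argparse
import os


def is_platform_debug_log_enabled() -> bool:
    return os.environ.get("LOG_LEVEL", "info").lower() == "debug"


def debug_enabled(parsed_args: argparse.Namespace) -> bool:
    return (hasattr(parsed_args, "debug") and parsed_args.debug) or is_platform_debug_log_enabled()


args = argparse.Namespace()  # parsed without --debug
os.environ["LOG_LEVEL"] = "DEBUG"
print(debug_enabled(args))  # True — the platform forces debug logs
os.environ["LOG_LEVEL"] = "info"
print(debug_enabled(args))  # False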

airbyte_cdk/logger.py

Lines changed: 21 additions & 3 deletions
@@ -1,10 +1,10 @@
 #
 # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
 #
-
 import json
 import logging
 import logging.config
+import os
 from typing import Any, Callable, Mapping, Optional, Tuple

 import orjson
@@ -40,6 +40,10 @@
 }


+def is_platform_debug_log_enabled() -> bool:
+    return os.environ.get("LOG_LEVEL", "info").lower() == "debug"
+
+
 def init_logger(name: Optional[str] = None) -> logging.Logger:
     """Initial set up of logger"""
     logger = logging.getLogger(name)
@@ -73,8 +77,22 @@ def format(self, record: logging.LogRecord) -> str:
         airbyte_level = self.level_mapping.get(record.levelno, "INFO")
         if airbyte_level == Level.DEBUG:
             extras = self.extract_extra_args_from_record(record)
-            debug_dict = {"type": "DEBUG", "message": record.getMessage(), "data": extras}
-            return filter_secrets(json.dumps(debug_dict))
+            if is_platform_debug_log_enabled():
+                # Debug logs enabled through the `--debug` argument behave differently from debug
+                # logs enabled through the environment variable: platform logs need to be printed
+                # as AirbyteMessage, which the previous implementation did not do.
+                # Why not migrate both to AirbyteMessage then? AirbyteMessage does not support
+                # structured logs, which would degrade the DX compared to the current solution
+                # (devs would need to find the `log.message` field and figure out where in that
+                # field the response is, while the current solution has a dedicated structured field for extras).
+                message = f"{filter_secrets(record.getMessage())} ///\nExtra logs: {filter_secrets(json.dumps(extras))}"
+                log_message = AirbyteMessage(
+                    type=Type.LOG, log=AirbyteLogMessage(level=airbyte_level, message=message)
+                )
+                return orjson.dumps(AirbyteMessageSerializer.dump(log_message)).decode()
+            else:
+                debug_dict = {"type": "DEBUG", "message": record.getMessage(), "data": extras}
+                return filter_secrets(json.dumps(debug_dict))
         else:
             message = super().format(record)
             message = filter_secrets(message)
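
To illustrate the trade-off described in the comment, here is a sketch of the two debug output shapes (hand-built payloads approximating the formatter's output, not produced by the formatter itself): the --debug path keeps a dedicated structured "data" field, while the platform path emits a protocol-shaped LOG message with the extras flattened into log.message.

import json

extras = {"url": "https://api.example.com/items", "status": 200}

# --debug CLI path: plain JSON with a structured "data" field for the extras.
print(json.dumps({"type": "DEBUG", "message": "request", "data": extras}))

# LOG_LEVEL=debug platform path: an AirbyteMessage-shaped LOG record; the extras
# survive only as text appended after the "///" separator inside log.message.
platform_message = {
    "type": "LOG",
    "log": {"level": "DEBUG", "message": "request ///\nExtra logs: " + json.dumps(extras)},
}
print(json.dumps(platform_message))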

airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py

Lines changed: 21 additions & 14 deletions
@@ -95,14 +95,11 @@ def on_partition(self, partition: Partition) -> None:
         """
         stream_name = partition.stream_name()
         self._streams_to_running_partitions[stream_name].add(partition)
-        cursor = self._stream_name_to_instance[stream_name].cursor
         if self._slice_logger.should_log_slice_message(self._logger):
             self._message_repository.emit_message(
                 self._slice_logger.create_slice_log_message(partition.to_slice())
             )
-        self._thread_pool_manager.submit(
-            self._partition_reader.process_partition, partition, cursor
-        )
+        self._thread_pool_manager.submit(self._partition_reader.process_partition, partition)

     def on_partition_complete_sentinel(
         self, sentinel: PartitionCompleteSentinel
@@ -115,16 +112,26 @@ def on_partition_complete_sentinel(
         """
         partition = sentinel.partition

-        partitions_running = self._streams_to_running_partitions[partition.stream_name()]
-        if partition in partitions_running:
-            partitions_running.remove(partition)
-            # If all partitions were generated and this was the last one, the stream is done
-            if (
-                partition.stream_name() not in self._streams_currently_generating_partitions
-                and len(partitions_running) == 0
-            ):
-                yield from self._on_stream_is_done(partition.stream_name())
-        yield from self._message_repository.consume_queue()
+        try:
+            if sentinel.is_successful:
+                stream = self._stream_name_to_instance[partition.stream_name()]
+                stream.cursor.close_partition(partition)
+        except Exception as exception:
+            self._flag_exception(partition.stream_name(), exception)
+            yield AirbyteTracedException.from_exception(
+                exception, stream_descriptor=StreamDescriptor(name=partition.stream_name())
+            ).as_sanitized_airbyte_message()
+        finally:
+            partitions_running = self._streams_to_running_partitions[partition.stream_name()]
+            if partition in partitions_running:
+                partitions_running.remove(partition)
+                # If all partitions were generated and this was the last one, the stream is done
+                if (
+                    partition.stream_name() not in self._streams_currently_generating_partitions
+                    and len(partitions_running) == 0
+                ):
+                    yield from self._on_stream_is_done(partition.stream_name())
+            yield from self._message_repository.consume_queue()

     def on_record(self, record: Record) -> Iterable[AirbyteMessage]:
         """
