Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions airbyte_cdk/sources/streams/concurrent/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,12 @@ def read(self) -> Iterable[Record]:
stream_name=self.stream_name(),
associated_slice=self._slice, # type: ignore [arg-type]
)
elif isinstance(record_data, AirbyteMessage) and record_data.record is not None:
yield Record(
data=record_data.record.data or {},
stream_name=self.stream_name(),
associated_slice=self._slice, # type: ignore [arg-type]
)
else:
self._message_repository.emit_message(record_data)
except Exception as e:
Expand Down
64 changes: 63 additions & 1 deletion unit_tests/sources/streams/concurrent/test_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,14 @@

import pytest

from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, AirbyteStream, Level, SyncMode
from airbyte_cdk.models import (
AirbyteLogMessage,
AirbyteMessage,
AirbyteRecordMessage,
AirbyteStream,
Level,
SyncMode,
)
from airbyte_cdk.models import Type as MessageType
from airbyte_cdk.sources.message import InMemoryMessageRepository
from airbyte_cdk.sources.streams.concurrent.adapters import (
Expand Down Expand Up @@ -132,6 +139,61 @@ def test_stream_partition(transformer, expected_records):
assert messages == [a_log_message]


@pytest.mark.parametrize(
"transformer, expected_records",
[
pytest.param(
TypeTransformer(TransformConfig.NoTransform),
[Record({"data": "1"}, None), Record({"data": "2"}, None)],
id="test_no_transform",
),
],
)
def test_stream_partition_read_airbyte_message(transformer, expected_records):
stream = Mock()
stream.name = _STREAM_NAME
stream.get_json_schema.return_value = {
"type": "object",
"properties": {"data": {"type": ["integer"]}},
}
stream.transformer = transformer
message_repository = InMemoryMessageRepository()
_slice = None
sync_mode = SyncMode.full_refresh
cursor_field = None
state = None
partition = StreamPartition(stream, _slice, message_repository, sync_mode, cursor_field, state)

a_log_message = AirbyteMessage(
type=MessageType.LOG,
log=AirbyteLogMessage(
level=Level.INFO,
message='slice:{"partition": 1}',
),
)
for record in expected_records:
record.partition = partition

stream_data = [
a_log_message,
AirbyteMessage(
type=MessageType.RECORD,
record=AirbyteRecordMessage(stream=stream.name, data={"data": "1"}, emitted_at=1),
),
AirbyteMessage(
type=MessageType.RECORD,
record=AirbyteRecordMessage(stream=stream.name, data={"data": "2"}, emitted_at=2),
),
]
stream.read_records.return_value = stream_data

records = list(partition.read())
messages = list(message_repository.consume_queue())

assert records == expected_records
assert messages == [a_log_message]


@pytest.mark.parametrize(
"exception_type, expected_display_message",
[
Expand Down
Loading