Skip to content

Commit edafc40

Browse files
authored
fix(concurrent-cdk): StreamPartition handles AirbyteRecords (#392)
1 parent 4f9fd20 commit edafc40

File tree

2 files changed

+72
-2
lines changed

2 files changed

+72
-2
lines changed

airbyte_cdk/sources/streams/concurrent/adapters.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def __init__(
276276
def read(self) -> Iterable[Record]:
277277
"""
278278
Read messages from the stream.
279-
If the StreamData is a Mapping, it will be converted to a Record.
279+
If the StreamData is a Mapping or an AirbyteMessage of type RECORD, it will be converted to a Record.
280280
Otherwise, the message will be emitted on the message repository.
281281
"""
282282
try:
@@ -292,6 +292,8 @@ def read(self) -> Iterable[Record]:
292292
stream_slice=copy.deepcopy(self._slice),
293293
stream_state=self._state,
294294
):
295+
# Noting we'll also need to support FileTransferRecordMessage if we want to support file-based connectors in this facade
296+
# For now, file-based connectors have their own stream facade
295297
if isinstance(record_data, Mapping):
296298
data_to_return = dict(record_data)
297299
self._stream.transformer.transform(
@@ -302,6 +304,12 @@ def read(self) -> Iterable[Record]:
302304
stream_name=self.stream_name(),
303305
associated_slice=self._slice, # type: ignore [arg-type]
304306
)
307+
elif isinstance(record_data, AirbyteMessage) and record_data.record is not None:
308+
yield Record(
309+
data=record_data.record.data or {},
310+
stream_name=self.stream_name(),
311+
associated_slice=self._slice, # type: ignore [arg-type]
312+
)
305313
else:
306314
self._message_repository.emit_message(record_data)
307315
except Exception as e:

unit_tests/sources/streams/concurrent/test_adapters.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,14 @@
77

88
import pytest
99

10-
from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, AirbyteStream, Level, SyncMode
10+
from airbyte_cdk.models import (
11+
AirbyteLogMessage,
12+
AirbyteMessage,
13+
AirbyteRecordMessage,
14+
AirbyteStream,
15+
Level,
16+
SyncMode,
17+
)
1118
from airbyte_cdk.models import Type as MessageType
1219
from airbyte_cdk.sources.message import InMemoryMessageRepository
1320
from airbyte_cdk.sources.streams.concurrent.adapters import (
@@ -132,6 +139,61 @@ def test_stream_partition(transformer, expected_records):
132139
assert messages == [a_log_message]
133140

134141

142+
@pytest.mark.parametrize(
143+
"transformer, expected_records",
144+
[
145+
pytest.param(
146+
TypeTransformer(TransformConfig.NoTransform),
147+
[Record({"data": "1"}, None), Record({"data": "2"}, None)],
148+
id="test_no_transform",
149+
),
150+
],
151+
)
152+
def test_stream_partition_read_airbyte_message(transformer, expected_records):
153+
stream = Mock()
154+
stream.name = _STREAM_NAME
155+
stream.get_json_schema.return_value = {
156+
"type": "object",
157+
"properties": {"data": {"type": ["integer"]}},
158+
}
159+
stream.transformer = transformer
160+
message_repository = InMemoryMessageRepository()
161+
_slice = None
162+
sync_mode = SyncMode.full_refresh
163+
cursor_field = None
164+
state = None
165+
partition = StreamPartition(stream, _slice, message_repository, sync_mode, cursor_field, state)
166+
167+
a_log_message = AirbyteMessage(
168+
type=MessageType.LOG,
169+
log=AirbyteLogMessage(
170+
level=Level.INFO,
171+
message='slice:{"partition": 1}',
172+
),
173+
)
174+
for record in expected_records:
175+
record.partition = partition
176+
177+
stream_data = [
178+
a_log_message,
179+
AirbyteMessage(
180+
type=MessageType.RECORD,
181+
record=AirbyteRecordMessage(stream=stream.name, data={"data": "1"}, emitted_at=1),
182+
),
183+
AirbyteMessage(
184+
type=MessageType.RECORD,
185+
record=AirbyteRecordMessage(stream=stream.name, data={"data": "2"}, emitted_at=2),
186+
),
187+
]
188+
stream.read_records.return_value = stream_data
189+
190+
records = list(partition.read())
191+
messages = list(message_repository.consume_queue())
192+
193+
assert records == expected_records
194+
assert messages == [a_log_message]
195+
196+
135197
@pytest.mark.parametrize(
136198
"exception_type, expected_display_message",
137199
[

0 commit comments

Comments
 (0)