Skip to content

Commit 9abc5fd

Browse files
committed
Add a test
1 parent 01f1eed commit 9abc5fd

File tree

1 file changed

+63
-1
lines changed

1 file changed

+63
-1
lines changed

unit_tests/sources/streams/concurrent/test_adapters.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,14 @@
77

88
import pytest
99

10-
from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, AirbyteStream, Level, SyncMode
10+
from airbyte_cdk.models import (
11+
AirbyteLogMessage,
12+
AirbyteMessage,
13+
AirbyteStream,
14+
Level,
15+
SyncMode,
16+
AirbyteRecordMessage,
17+
)
1118
from airbyte_cdk.models import Type as MessageType
1219
from airbyte_cdk.sources.message import InMemoryMessageRepository
1320
from airbyte_cdk.sources.streams.concurrent.adapters import (
@@ -132,6 +139,61 @@ def test_stream_partition(transformer, expected_records):
132139
assert messages == [a_log_message]
133140

134141

142+
@pytest.mark.parametrize(
143+
"transformer, expected_records",
144+
[
145+
pytest.param(
146+
TypeTransformer(TransformConfig.NoTransform),
147+
[Record({"data": "1"}, None), Record({"data": "2"}, None)],
148+
id="test_no_transform",
149+
),
150+
],
151+
)
152+
def test_stream_partition_read_airbyte_message(transformer, expected_records):
153+
stream = Mock()
154+
stream.name = _STREAM_NAME
155+
stream.get_json_schema.return_value = {
156+
"type": "object",
157+
"properties": {"data": {"type": ["integer"]}},
158+
}
159+
stream.transformer = transformer
160+
message_repository = InMemoryMessageRepository()
161+
_slice = None
162+
sync_mode = SyncMode.full_refresh
163+
cursor_field = None
164+
state = None
165+
partition = StreamPartition(stream, _slice, message_repository, sync_mode, cursor_field, state)
166+
167+
a_log_message = AirbyteMessage(
168+
type=MessageType.LOG,
169+
log=AirbyteLogMessage(
170+
level=Level.INFO,
171+
message='slice:{"partition": 1}',
172+
),
173+
)
174+
for record in expected_records:
175+
record.partition = partition
176+
177+
stream_data = [
178+
a_log_message,
179+
AirbyteMessage(
180+
type=MessageType.RECORD,
181+
record=AirbyteRecordMessage(stream=stream.name, data={"data": "1"}),
182+
),
183+
AirbyteMessage(
184+
type=MessageType.RECORD,
185+
record=AirbyteRecordMessage(stream=stream.name, data={"data": "2"}),
186+
),
187+
]
188+
stream.read_records.return_value = stream_data
189+
190+
records = list(partition.read())
191+
messages = list(message_repository.consume_queue())
192+
193+
assert records == expected_records
194+
assert messages == [a_log_message]
195+
196+
135197
@pytest.mark.parametrize(
136198
"exception_type, expected_display_message",
137199
[

0 commit comments

Comments
 (0)