|
7 | 7 |
|
8 | 8 | import pytest |
9 | 9 |
|
10 | | -from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, AirbyteStream, Level, SyncMode |
| 10 | +from airbyte_cdk.models import ( |
| 11 | + AirbyteLogMessage, |
| 12 | + AirbyteMessage, |
| 13 | + AirbyteStream, |
| 14 | + Level, |
| 15 | + SyncMode, |
| 16 | + AirbyteRecordMessage, |
| 17 | +) |
11 | 18 | from airbyte_cdk.models import Type as MessageType |
12 | 19 | from airbyte_cdk.sources.message import InMemoryMessageRepository |
13 | 20 | from airbyte_cdk.sources.streams.concurrent.adapters import ( |
@@ -132,6 +139,61 @@ def test_stream_partition(transformer, expected_records): |
132 | 139 | assert messages == [a_log_message] |
133 | 140 |
|
134 | 141 |
|
| 142 | +@pytest.mark.parametrize( |
| 143 | + "transformer, expected_records", |
| 144 | + [ |
| 145 | + pytest.param( |
| 146 | + TypeTransformer(TransformConfig.NoTransform), |
| 147 | + [Record({"data": "1"}, None), Record({"data": "2"}, None)], |
| 148 | + id="test_no_transform", |
| 149 | + ), |
| 150 | + ], |
| 151 | +) |
| 152 | +def test_stream_partition_read_airbyte_message(transformer, expected_records): |
| 153 | + stream = Mock() |
| 154 | + stream.name = _STREAM_NAME |
| 155 | + stream.get_json_schema.return_value = { |
| 156 | + "type": "object", |
| 157 | + "properties": {"data": {"type": ["integer"]}}, |
| 158 | + } |
| 159 | + stream.transformer = transformer |
| 160 | + message_repository = InMemoryMessageRepository() |
| 161 | + _slice = None |
| 162 | + sync_mode = SyncMode.full_refresh |
| 163 | + cursor_field = None |
| 164 | + state = None |
| 165 | + partition = StreamPartition(stream, _slice, message_repository, sync_mode, cursor_field, state) |
| 166 | + |
| 167 | + a_log_message = AirbyteMessage( |
| 168 | + type=MessageType.LOG, |
| 169 | + log=AirbyteLogMessage( |
| 170 | + level=Level.INFO, |
| 171 | + message='slice:{"partition": 1}', |
| 172 | + ), |
| 173 | + ) |
| 174 | + for record in expected_records: |
| 175 | + record.partition = partition |
| 176 | + |
| 177 | + stream_data = [ |
| 178 | + a_log_message, |
| 179 | + AirbyteMessage( |
| 180 | + type=MessageType.RECORD, |
| 181 | + record=AirbyteRecordMessage(stream=stream.name, data={"data": "1"}), |
| 182 | + ), |
| 183 | + AirbyteMessage( |
| 184 | + type=MessageType.RECORD, |
| 185 | + record=AirbyteRecordMessage(stream=stream.name, data={"data": "2"}), |
| 186 | + ), |
| 187 | + ] |
| 188 | + stream.read_records.return_value = stream_data |
| 189 | + |
| 190 | + records = list(partition.read()) |
| 191 | + messages = list(message_repository.consume_queue()) |
| 192 | + |
| 193 | + assert records == expected_records |
| 194 | + assert messages == [a_log_message] |
| 195 | + |
| 196 | + |
135 | 197 | @pytest.mark.parametrize( |
136 | 198 | "exception_type, expected_display_message", |
137 | 199 | [ |
|
0 commit comments