+import json
+from unittest.mock import patch
+
 import pyarrow as pa
 import pytest
+from kafka import KafkaConsumer
+from kafka.errors import KafkaError
 
 try:
     from src.amp.loaders.implementations.kafka_loader import KafkaLoader
+    from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch
 except ImportError:
     pytest.skip('amp modules not available', allow_module_level=True)
 
@@ -42,3 +48,209 @@ def test_load_batch(self, kafka_test_config): |
 
         assert result.success is True
         assert result.rows_loaded == 3
+
+    def test_message_consumption_verification(self, kafka_test_config):
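+        """Produce one batch through the loader, then consume the topic and verify each message round-trips."""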
+        loader = KafkaLoader(kafka_test_config)
+        topic_name = 'test_consumption_topic'
+
+        batch = pa.RecordBatch.from_pydict(
+            {
+                'id': [1, 2, 3],
+                'name': ['alice', 'bob', 'charlie'],
+                'score': [100, 200, 150],
+                'active': [True, False, True],
+            }
+        )
+
+        with loader:
+            result = loader.load_batch(batch, topic_name)
+
+        assert result.success is True
+        assert result.rows_loaded == 3
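+
+        # Read the topic back from the beginning; consumer_timeout_ms ends iteration once no more messages arrive.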
+        consumer = KafkaConsumer(
+            topic_name,
+            bootstrap_servers=kafka_test_config['bootstrap_servers'],
+            auto_offset_reset='earliest',
+            consumer_timeout_ms=5000,
+            value_deserializer=lambda m: json.loads(m.decode('utf-8')),
+        )
+
+        messages = list(consumer)
+        consumer.close()
+
+        assert len(messages) == 3
+
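+        # Keys are the row ids as UTF-8 strings; ordering assumes the test topic has a single partition.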
+        for i, msg in enumerate(messages):
+            assert msg.key == str(i + 1).encode('utf-8')
+            assert msg.value['_type'] == 'data'
+            assert msg.value['id'] == i + 1
+            assert msg.value['name'] in ['alice', 'bob', 'charlie']
+            assert msg.value['score'] in [100, 200, 150]
+            assert msg.value['active'] in [True, False]
+
+        msg1 = messages[0]
+        assert msg1.value['id'] == 1
+        assert msg1.value['name'] == 'alice'
+        assert msg1.value['score'] == 100
+        assert msg1.value['active'] is True
+
+        msg2 = messages[1]
+        assert msg2.value['id'] == 2
+        assert msg2.value['name'] == 'bob'
+        assert msg2.value['score'] == 200
+        assert msg2.value['active'] is False
+
+        msg3 = messages[2]
+        assert msg3.value['id'] == 3
+        assert msg3.value['name'] == 'charlie'
+        assert msg3.value['score'] == 150
+        assert msg3.value['active'] is True
+
+    def test_handle_reorg(self, kafka_test_config):
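+        """_handle_reorg should publish one reorg notification per invalidated block range."""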
+        loader = KafkaLoader(kafka_test_config)
+        topic_name = 'test_reorg_topic'
+
+        invalidation_ranges = [
+            BlockRange(network='ethereum', start=100, end=200),
+            BlockRange(network='polygon', start=500, end=600),
+        ]
+
+        with loader:
+            loader._handle_reorg(invalidation_ranges, topic_name, 'test_connection')
+
+        consumer = KafkaConsumer(
+            topic_name,
+            bootstrap_servers=kafka_test_config['bootstrap_servers'],
+            auto_offset_reset='earliest',
+            consumer_timeout_ms=5000,
+            value_deserializer=lambda m: json.loads(m.decode('utf-8')),
+        )
+
+        messages = list(consumer)
+        consumer.close()
+
+        assert len(messages) == 2
+
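+        # One notification per range, emitted in the order the ranges were given.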
+        msg1 = messages[0]
+        assert msg1.key == b'reorg:ethereum'
+        assert msg1.value['_type'] == 'reorg'
+        assert msg1.value['network'] == 'ethereum'
+        assert msg1.value['start_block'] == 100
+        assert msg1.value['end_block'] == 200
+
+        msg2 = messages[1]
+        assert msg2.key == b'reorg:polygon'
+        assert msg2.value['_type'] == 'reorg'
+        assert msg2.value['network'] == 'polygon'
+        assert msg2.value['start_block'] == 500
+        assert msg2.value['end_block'] == 600
+
+    def test_streaming_with_reorg(self, kafka_test_config):
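+        """A reorg event mid-stream should be forwarded as a notification without disrupting data batches."""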
+        loader = KafkaLoader(kafka_test_config)
+        topic_name = 'test_streaming_topic'
+
+        data1 = pa.RecordBatch.from_pydict({'id': [1, 2], 'value': [100, 200]})
+        data2 = pa.RecordBatch.from_pydict({'id': [3, 4], 'value': [300, 400]})
+        data3 = pa.RecordBatch.from_pydict({'id': [5, 6], 'value': [500, 600]})
+
+        response1 = ResponseBatch.data_batch(
+            data=data1,
+            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')]),
+        )
+
+        response2 = ResponseBatch.data_batch(
+            data=data2,
+            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=110, end=120, hash='0xdef456')]),
+        )
+
+        reorg_response = ResponseBatch.reorg_batch(
+            invalidation_ranges=[BlockRange(network='ethereum', start=110, end=200)]
+        )
+
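+        # response3 replays the invalidated range (blocks 110-120) with a new block hash.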
+        response3 = ResponseBatch.data_batch(
+            data=data3,
+            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=110, end=120, hash='0xnew123')]),
+        )
+
+        stream = [response1, response2, reorg_response, response3]
+
+        with loader:
+            results = list(loader.load_stream_continuous(iter(stream), topic_name))
+
+        assert len(results) == 4
+        assert results[0].success
+        assert results[0].rows_loaded == 2
+        assert results[1].success
+        assert results[1].rows_loaded == 2
+        assert results[2].success
+        assert results[2].is_reorg
+        assert results[3].success
+        assert results[3].rows_loaded == 2
+
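+        # Expect 7 messages on the topic: 6 data rows plus 1 reorg notification.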
+        consumer = KafkaConsumer(
+            topic_name,
+            bootstrap_servers=kafka_test_config['bootstrap_servers'],
+            auto_offset_reset='earliest',
+            consumer_timeout_ms=5000,
+            value_deserializer=lambda m: json.loads(m.decode('utf-8')),
+        )
+
+        messages = list(consumer)
+        consumer.close()
+
+        assert len(messages) == 7
+
+        data_messages = [msg for msg in messages if msg.value.get('_type') == 'data']
+        reorg_messages = [msg for msg in messages if msg.value.get('_type') == 'reorg']
+
+        assert len(data_messages) == 6
+        assert len(reorg_messages) == 1
+
+        assert reorg_messages[0].key == b'reorg:ethereum'
+        assert reorg_messages[0].value['network'] == 'ethereum'
+        assert reorg_messages[0].value['start_block'] == 110
+        assert reorg_messages[0].value['end_block'] == 200
+
+        data_ids = [msg.value['id'] for msg in data_messages]
+        assert data_ids == [1, 2, 3, 4, 5, 6]
+
+    def test_transaction_rollback_on_error(self, kafka_test_config):
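+        """A mid-batch send failure should abort the whole batch, leaving the topic empty."""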
+        loader = KafkaLoader(kafka_test_config)
+        topic_name = 'test_rollback_topic'
+
+        batch = pa.RecordBatch.from_pydict(
+            {
+                'id': [1, 2, 3, 4, 5],
+                'name': ['alice', 'bob', 'charlie', 'dave', 'eve'],
+                'value': [100, 200, 300, 400, 500],
+            }
+        )
+
+        with loader:
+            call_count = [0]
+
+            original_send = loader._producer.send
+
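+            # Let the first two sends through, then fail the third to simulate a broker error mid-batch.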
+            def failing_send(*args, **kwargs):
+                call_count[0] += 1
+                if call_count[0] == 3:
+                    raise KafkaError('Simulated Kafka send failure')
+                return original_send(*args, **kwargs)
+
+            with patch.object(loader._producer, 'send', side_effect=failing_send):
+                with pytest.raises(RuntimeError, match='FATAL: Permanent error loading batch'):
+                    loader.load_batch(batch, topic_name)
+
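+        # With the batch aborted, none of the five rows should be visible to a consumer.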
+        consumer = KafkaConsumer(
+            topic_name,
+            bootstrap_servers=kafka_test_config['bootstrap_servers'],
+            auto_offset_reset='earliest',
+            consumer_timeout_ms=5000,
+            value_deserializer=lambda m: json.loads(m.decode('utf-8')),
+        )
+
+        messages = list(consumer)
+        consumer.close()
+
+        assert len(messages) == 0