
Commit 11ad52f

Add more integration tests for kafka loader
1 parent 30ef80f

File tree

3 files changed, +223 -1 lines changed

src/amp/loaders/implementations/kafka_loader.py

Lines changed: 6 additions & 1 deletion
@@ -105,12 +105,17 @@ def _extract_message_key(self, row: Dict[str, Any]) -> Optional[bytes]:
 
         return str(key_value).encode('utf-8')
 
-    def _handle_reorg(self, invalidation_ranges: List[BlockRange], table_name: str) -> None:
+    def _handle_reorg(self, invalidation_ranges: List[BlockRange], table_name: str, connection_name: str) -> None:
         """
         Handle blockchain reorganization by sending reorg events to the same topic.
 
         Reorg events are sent as special messages with _type='reorg' so consumers
         can detect and handle invalidated block ranges.
+
+        Args:
+            invalidation_ranges: List of block ranges to invalidate
+            table_name: The Kafka topic name to send reorg events to
+            connection_name: Connection identifier (unused for Kafka, but required by base class)
         """
         if not invalidation_ranges:
             return
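
The new tests below pin down the wire format of these reorg events: the message key is 'reorg:<network>' and the JSON value carries _type, network, start_block, and end_block. As a rough sketch of what the emission loop has to do to satisfy those assertions (shapes taken from the tests, not from the loader source, which this hunk elides):

    # Hypothetical reconstruction -- field names come from the assertions in
    # tests/integration/test_kafka_loader.py; the real implementation lives in
    # kafka_loader.py and may differ in details (e.g. it presumably relies on
    # a JSON value serializer configured on the producer).
    for block_range in invalidation_ranges:
        self._producer.send(
            table_name,  # reorg events go to the same topic as data messages
            key=f'reorg:{block_range.network}'.encode('utf-8'),
            value={
                '_type': 'reorg',
                'network': block_range.network,
                'start_block': block_range.start,
                'end_block': block_range.end,
            },
        )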

tests/conftest.py

Lines changed: 5 additions & 0 deletions
@@ -210,7 +210,12 @@ def kafka_container():
     if not TESTCONTAINERS_AVAILABLE:
         pytest.skip('Testcontainers not available')
 
+    # Configure Kafka for transactions in single-broker setup
+    # These settings are required for transactional producers to work
     container = KafkaContainer(image='confluentinc/cp-kafka:7.6.0')
+    container.with_env('KAFKA_TRANSACTION_STATE_LOG_REPLICATION_FACTOR', '1')
+    container.with_env('KAFKA_TRANSACTION_STATE_LOG_MIN_ISR', '1')
+    container.with_env('KAFKA_OFFSETS_TOPIC_REPLICATION_FACTOR', '1')
     container.start()
 
     time.sleep(10)
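
These overrides matter because, with a single broker, Kafka's internal __transaction_state and __consumer_offsets topics cannot satisfy the default replication factor of 3, so a transactional producer would fail at initialization. On the consuming side, a client that should see only committed transactional writes opts into read-committed isolation; a minimal sketch, assuming kafka-python 2.x (whose KafkaConsumer accepts an isolation_level option, defaulting to 'read_uncommitted'):

    import json
    from kafka import KafkaConsumer

    # 'read_committed' hides records from open or aborted transactions.
    consumer = KafkaConsumer(
        'some_topic',  # hypothetical topic name
        bootstrap_servers='localhost:9092',
        isolation_level='read_committed',
        auto_offset_reset='earliest',
        consumer_timeout_ms=5000,
        value_deserializer=lambda m: json.loads(m.decode('utf-8')),
    )
    committed_messages = list(consumer)
    consumer.close()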

tests/integration/test_kafka_loader.py

Lines changed: 212 additions & 0 deletions
@@ -1,8 +1,14 @@
+import json
+from unittest.mock import patch
+
 import pyarrow as pa
 import pytest
+from kafka import KafkaConsumer
+from kafka.errors import KafkaError
 
 try:
     from src.amp.loaders.implementations.kafka_loader import KafkaLoader
+    from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch
 except ImportError:
     pytest.skip('amp modules not available', allow_module_level=True)
 

@@ -42,3 +48,209 @@ def test_load_batch(self, kafka_test_config):
 
         assert result.success == True
         assert result.rows_loaded == 3
+
+    def test_message_consumption_verification(self, kafka_test_config):
+        loader = KafkaLoader(kafka_test_config)
+        topic_name = 'test_consumption_topic'
+
+        batch = pa.RecordBatch.from_pydict(
+            {
+                'id': [1, 2, 3],
+                'name': ['alice', 'bob', 'charlie'],
+                'score': [100, 200, 150],
+                'active': [True, False, True],
+            }
+        )
+
+        with loader:
+            result = loader.load_batch(batch, topic_name)
+
+            assert result.success is True
+            assert result.rows_loaded == 3
+
+        consumer = KafkaConsumer(
+            topic_name,
+            bootstrap_servers=kafka_test_config['bootstrap_servers'],
+            auto_offset_reset='earliest',
+            consumer_timeout_ms=5000,
+            value_deserializer=lambda m: json.loads(m.decode('utf-8')),
+        )
+
+        messages = list(consumer)
+        consumer.close()
+
+        assert len(messages) == 3
+
+        for i, msg in enumerate(messages):
+            assert msg.key == str(i + 1).encode('utf-8')
+            assert msg.value['_type'] == 'data'
+            assert msg.value['id'] == i + 1
+            assert msg.value['name'] in ['alice', 'bob', 'charlie']
+            assert msg.value['score'] in [100, 200, 150]
+            assert msg.value['active'] in [True, False]
+
+        msg1 = messages[0]
+        assert msg1.value['id'] == 1
+        assert msg1.value['name'] == 'alice'
+        assert msg1.value['score'] == 100
+        assert msg1.value['active'] is True
+
+        msg2 = messages[1]
+        assert msg2.value['id'] == 2
+        assert msg2.value['name'] == 'bob'
+        assert msg2.value['score'] == 200
+        assert msg2.value['active'] is False
+
+        msg3 = messages[2]
+        assert msg3.value['id'] == 3
+        assert msg3.value['name'] == 'charlie'
+        assert msg3.value['score'] == 150
+        assert msg3.value['active'] is True
+
+    def test_handle_reorg(self, kafka_test_config):
+        loader = KafkaLoader(kafka_test_config)
+        topic_name = 'test_reorg_topic'
+
+        invalidation_ranges = [
+            BlockRange(network='ethereum', start=100, end=200),
+            BlockRange(network='polygon', start=500, end=600),
+        ]
+
+        with loader:
+            loader._handle_reorg(invalidation_ranges, topic_name, 'test_connection')
+
+        consumer = KafkaConsumer(
+            topic_name,
+            bootstrap_servers=kafka_test_config['bootstrap_servers'],
+            auto_offset_reset='earliest',
+            consumer_timeout_ms=5000,
+            value_deserializer=lambda m: json.loads(m.decode('utf-8')),
+        )
+
+        messages = list(consumer)
+        consumer.close()
+
+        assert len(messages) == 2
+
+        msg1 = messages[0]
+        assert msg1.key == b'reorg:ethereum'
+        assert msg1.value['_type'] == 'reorg'
+        assert msg1.value['network'] == 'ethereum'
+        assert msg1.value['start_block'] == 100
+        assert msg1.value['end_block'] == 200
+
+        msg2 = messages[1]
+        assert msg2.key == b'reorg:polygon'
+        assert msg2.value['_type'] == 'reorg'
+        assert msg2.value['network'] == 'polygon'
+        assert msg2.value['start_block'] == 500
+        assert msg2.value['end_block'] == 600
+
+    def test_streaming_with_reorg(self, kafka_test_config):
+        loader = KafkaLoader(kafka_test_config)
+        topic_name = 'test_streaming_topic'
+
+        data1 = pa.RecordBatch.from_pydict({'id': [1, 2], 'value': [100, 200]})
+        data2 = pa.RecordBatch.from_pydict({'id': [3, 4], 'value': [300, 400]})
+        data3 = pa.RecordBatch.from_pydict({'id': [5, 6], 'value': [500, 600]})
+
+        response1 = ResponseBatch.data_batch(
+            data=data1,
+            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')]),
+        )
+
+        response2 = ResponseBatch.data_batch(
+            data=data2,
+            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=110, end=120, hash='0xdef456')]),
+        )
+
+        reorg_response = ResponseBatch.reorg_batch(
+            invalidation_ranges=[BlockRange(network='ethereum', start=110, end=200)]
+        )
+
+        response3 = ResponseBatch.data_batch(
+            data=data3,
+            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=110, end=120, hash='0xnew123')]),
+        )
+
+        stream = [response1, response2, reorg_response, response3]
+
+        with loader:
+            results = list(loader.load_stream_continuous(iter(stream), topic_name))
+
+        assert len(results) == 4
+        assert results[0].success
+        assert results[0].rows_loaded == 2
+        assert results[1].success
+        assert results[1].rows_loaded == 2
+        assert results[2].success
+        assert results[2].is_reorg
+        assert results[3].success
+        assert results[3].rows_loaded == 2
+
+        consumer = KafkaConsumer(
+            topic_name,
+            bootstrap_servers=kafka_test_config['bootstrap_servers'],
+            auto_offset_reset='earliest',
+            consumer_timeout_ms=5000,
+            value_deserializer=lambda m: json.loads(m.decode('utf-8')),
+        )
+
+        messages = list(consumer)
+        consumer.close()
+
+        assert len(messages) == 7
+
+        data_messages = [msg for msg in messages if msg.value.get('_type') == 'data']
+        reorg_messages = [msg for msg in messages if msg.value.get('_type') == 'reorg']
+
+        assert len(data_messages) == 6
+        assert len(reorg_messages) == 1
+
+        assert reorg_messages[0].key == b'reorg:ethereum'
+        assert reorg_messages[0].value['network'] == 'ethereum'
+        assert reorg_messages[0].value['start_block'] == 110
+        assert reorg_messages[0].value['end_block'] == 200
+
+        data_ids = [msg.value['id'] for msg in data_messages]
+        assert data_ids == [1, 2, 3, 4, 5, 6]
+
+    def test_transaction_rollback_on_error(self, kafka_test_config):
+        loader = KafkaLoader(kafka_test_config)
+        topic_name = 'test_rollback_topic'
+
+        batch = pa.RecordBatch.from_pydict(
+            {
+                'id': [1, 2, 3, 4, 5],
+                'name': ['alice', 'bob', 'charlie', 'dave', 'eve'],
+                'value': [100, 200, 300, 400, 500],
+            }
+        )
+
+        with loader:
+            call_count = [0]
+
+            original_send = loader._producer.send
+
+            def failing_send(*args, **kwargs):
+                call_count[0] += 1
+                if call_count[0] == 3:
+                    raise KafkaError('Simulated Kafka send failure')
+                return original_send(*args, **kwargs)
+
+            with patch.object(loader._producer, 'send', side_effect=failing_send):
+                with pytest.raises(RuntimeError, match='FATAL: Permanent error loading batch'):
+                    loader.load_batch(batch, topic_name)
+
+        consumer = KafkaConsumer(
+            topic_name,
+            bootstrap_servers=kafka_test_config['bootstrap_servers'],
+            auto_offset_reset='earliest',
+            consumer_timeout_ms=5000,
+            value_deserializer=lambda m: json.loads(m.decode('utf-8')),
+        )
+
+        messages = list(consumer)
+        consumer.close()
+
+        assert len(messages) == 0
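
The rollback test relies on the loader wrapping each batch in a Kafka transaction: when the third send raises, the whole batch is aborted and none of its records ever become visible to consumers, hence the expectation of zero messages. A minimal sketch of that produce-or-abort pattern, assuming a client with transactional-producer support (recent kafka-python 2.x releases expose init_transactions/begin_transaction/commit_transaction/abort_transaction; the loader's actual internals are not shown in this diff):

    from kafka import KafkaProducer

    producer = KafkaProducer(
        bootstrap_servers='localhost:9092',
        transactional_id='amp-kafka-loader',  # hypothetical id; must be stable per producer
    )
    producer.init_transactions()
    producer.begin_transaction()
    try:
        for key, value in records:  # records: iterable of (bytes, bytes) pairs, assumed
            producer.send('some_topic', key=key, value=value)
        producer.commit_transaction()  # all records become visible atomically
    except Exception:
        producer.abort_transaction()  # read-committed consumers see none of them
        raise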
