Skip to content

Commit 2cfe41b

Browse files
committed
Better crash recovery with state store
1 parent 8d45a01 commit 2cfe41b

File tree

2 files changed

+208
-1
lines changed

2 files changed

+208
-1
lines changed

src/amp/loaders/base.py

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,19 @@ def __init__(self, config: Dict[str, Any], label_manager=None) -> None:
7777
else:
7878
self.state_store = NullStreamStateStore()
7979

80+
# Track tables that have undergone crash recovery
81+
self._crash_recovery_done: set[str] = set()
82+
8083
@property
def is_connected(self) -> bool:
    """Report the loader's current connection flag (``self._is_connected``)."""
    connected = self._is_connected
    return connected
8487

88+
@property
def loader_type(self) -> str:
    """
    Return the loader type identifier derived from the class name.

    Every occurrence of 'Loader' is removed from the class name and the
    remainder is lower-cased (e.g., 'RedisLoader' -> 'redis').
    """
    class_name = self.__class__.__name__
    without_suffix = class_name.replace('Loader', '')
    return without_suffix.lower()
92+
8593
def _parse_config(self, config: Dict[str, Any]) -> TConfig:
8694
"""
8795
Parse configuration into loader-specific format.
@@ -446,11 +454,21 @@ def load_stream_continuous(
446454
if not self._is_connected:
447455
self.connect()
448456

457+
connection_name = kwargs.get('connection_name')
458+
if connection_name is None:
459+
connection_name = self.loader_type
460+
461+
if table_name not in self._crash_recovery_done:
462+
self.logger.info(f'Running crash recovery for table {table_name} (connection: {connection_name})')
463+
self._rewind_to_watermark(table_name, connection_name)
464+
self._crash_recovery_done.add(table_name)
465+
else:
466+
self.logger.info(f'Crash recovery already done for table {table_name}')
467+
449468
rows_loaded = 0
450469
start_time = time.time()
451470
batch_count = 0
452471
reorg_count = 0
453-
connection_name = kwargs.get('connection_name', 'unknown')
454472
worker_id = kwargs.get('worker_id', 0)
455473

456474
try:
@@ -748,6 +766,80 @@ def _handle_reorg(self, invalidation_ranges: List[BlockRange], table_name: str,
748766
'Streaming with reorg detection requires implementing this method.'
749767
)
750768

769+
def _rewind_to_watermark(self, table_name: Optional[str] = None, connection_name: Optional[str] = None) -> None:
    """
    Reset state and data for a table to its last checkpointed watermark.

    For each network range in the table's stored resume position, everything
    from the block after the watermark (``range.end + 1``) onwards is passed to
    ``_handle_reorg()`` as an invalidation range. If the loader does not
    implement ``_handle_reorg()`` (it raises NotImplementedError), only the
    state store is invalidated via ``invalidate_from_block()`` and a warning is
    logged when batches were cleared, since the loader cannot delete the data.

    The call is a no-op when state tracking is disabled, when ``table_name`` is
    None, or when the state store returns no resume position for the table.

    Args:
        table_name: Table to clean up. Passing None (all tables) is not
            implemented yet: a debug message is logged and nothing is cleaned.
        connection_name: Connection identifier used for state-store lookups.
            If None, ``self.loader_type`` is used.

    Example:
        # Before streaming into a table:
        self._rewind_to_watermark(table_name, connection_name)
    """
    if not self.state_enabled:
        self.logger.debug('State tracking disabled, skipping crash recovery')
        return

    if table_name is None:
        self.logger.debug('table_name=None not yet implemented, skipping crash recovery')
        return

    if connection_name is None:
        connection_name = self.loader_type

    resume_pos = self.state_store.get_resume_position(connection_name, table_name)
    if not resume_pos:
        self.logger.debug(f'No watermark found for {table_name}, skipping crash recovery')
        return

    for range_obj in resume_pos.ranges:
        network = range_obj.network
        # Everything after the watermark's last block belongs to an incomplete microbatch
        from_block = range_obj.end + 1

        self.logger.info(
            f'Crash recovery: Cleaning up {table_name} data for {network} from block {from_block} onwards'
        )

        # BlockRange requires an 'end' field, but the invalidation is open-ended;
        # end=from_block is only a placeholder.
        invalidation_ranges = [BlockRange(network=network, start=from_block, end=from_block)]

        try:
            self._handle_reorg(invalidation_ranges, table_name, connection_name)
        except NotImplementedError:
            # Loader cannot delete data: fall back to clearing state only
            invalidated = self.state_store.invalidate_from_block(connection_name, table_name, network, from_block)

            if invalidated:
                self.logger.warning(
                    f'Crash recovery: Cleared {len(invalidated)} batches from state '
                    f'for {network} but cannot delete data from {table_name}. '
                    f'{self.__class__.__name__} does not support data deletion. '
                    f'Duplicates may occur on resume.'
                )
            else:
                self.logger.debug(f'No uncommitted batches found for {network}')
        else:
            self.logger.info(f'Crash recovery completed for {network} in {table_name}')
842+
751843
def _add_metadata_columns(self, data: pa.RecordBatch, block_ranges: List[BlockRange]) -> pa.RecordBatch:
752844
"""
753845
Add metadata columns for streaming data with compact batch identification.

tests/unit/test_crash_recovery.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
"""
2+
Unit tests for crash recovery via _rewind_to_watermark() method.
3+
4+
These tests verify the crash recovery logic works correctly in isolation.
5+
"""
6+
7+
from unittest.mock import Mock
8+
9+
import pytest
10+
11+
from src.amp.loaders.base import LoadResult
12+
from src.amp.streaming.types import BlockRange, ResumeWatermark
13+
from tests.fixtures.mock_clients import MockDataLoader
14+
15+
16+
@pytest.fixture
def mock_loader() -> MockDataLoader:
    """Provide a connected MockDataLoader with a Mock state store and state tracking switched on."""
    connected_loader = MockDataLoader({'test': 'config'})
    connected_loader.connect()

    # Replace the state store with a Mock so tests can assert on its calls
    connected_loader.state_store = Mock()
    connected_loader.state_enabled = True

    return connected_loader
26+
27+
28+
@pytest.mark.unit
class TestCrashRecovery:
    """Exercise the _rewind_to_watermark() crash recovery method in isolation."""

    @staticmethod
    def _ethereum_watermark():
        """Build a watermark holding a single ethereum range ending at block 1010."""
        eth_range = BlockRange(network='ethereum', start=1000, end=1010, hash='0xabc')
        return ResumeWatermark(ranges=[eth_range])

    def test_rewind_with_no_state(self, mock_loader):
        """With state_enabled=False the state store must never be queried."""
        mock_loader.state_enabled = False

        mock_loader._rewind_to_watermark('test_table', 'test_conn')

        mock_loader.state_store.get_resume_position.assert_not_called()

    def test_rewind_with_no_watermark(self, mock_loader):
        """A missing watermark results in exactly one lookup and nothing else."""
        lookup = Mock(return_value=None)
        mock_loader.state_store.get_resume_position = lookup

        mock_loader._rewind_to_watermark('test_table', 'test_conn')

        lookup.assert_called_once_with('test_conn', 'test_table')

    def test_rewind_calls_handle_reorg(self, mock_loader):
        """_handle_reorg receives one range starting right after the watermark."""
        mock_loader.state_store.get_resume_position = Mock(return_value=self._ethereum_watermark())
        reorg = Mock()
        mock_loader._handle_reorg = reorg

        mock_loader._rewind_to_watermark('test_table', 'test_conn')

        reorg.assert_called_once()
        positional = reorg.call_args[0]
        passed_ranges = positional[0]
        assert len(passed_ranges) == 1
        assert passed_ranges[0].network == 'ethereum'
        assert passed_ranges[0].start == 1011
        assert positional[1] == 'test_table'
        assert positional[2] == 'test_conn'

    def test_rewind_handles_not_implemented(self, mock_loader):
        """Loaders lacking _handle_reorg fall back to state-store invalidation."""
        store = mock_loader.state_store
        store.get_resume_position = Mock(return_value=self._ethereum_watermark())
        store.invalidate_from_block = Mock(return_value=[])
        mock_loader._handle_reorg = Mock(side_effect=NotImplementedError())

        mock_loader._rewind_to_watermark('test_table', 'test_conn')

        store.invalidate_from_block.assert_called_once_with('test_conn', 'test_table', 'ethereum', 1011)

    def test_rewind_with_multiple_networks(self, mock_loader):
        """Each network in the watermark gets its own _handle_reorg call, in order."""
        both_networks = ResumeWatermark(
            ranges=[
                BlockRange(network='ethereum', start=1000, end=1010, hash='0xabc'),
                BlockRange(network='polygon', start=2000, end=2010, hash='0xdef'),
            ]
        )
        mock_loader.state_store.get_resume_position = Mock(return_value=both_networks)
        reorg = Mock()
        mock_loader._handle_reorg = reorg

        mock_loader._rewind_to_watermark('test_table', 'test_conn')

        assert reorg.call_count == 2

        expected = [('ethereum', 1011), ('polygon', 2011)]
        for recorded, (network, start) in zip(reorg.call_args_list, expected):
            leading_range = recorded[0][0][0]
            assert leading_range.network == network
            assert leading_range.start == start

    def test_rewind_with_table_name_none(self, mock_loader):
        """table_name=None is not implemented yet, so no lookup may happen."""
        mock_loader._rewind_to_watermark(table_name=None, connection_name='test_conn')

        mock_loader.state_store.get_resume_position.assert_not_called()

    def test_rewind_uses_default_connection_name(self, mock_loader):
        """Without a connection name, the name derived from the loader class is used."""
        lookup = Mock(return_value=self._ethereum_watermark())
        mock_loader.state_store.get_resume_position = lookup
        mock_loader._handle_reorg = Mock()

        mock_loader._rewind_to_watermark('test_table', connection_name=None)

        lookup.assert_called_once_with('mockdata', 'test_table')

0 commit comments

Comments
 (0)