Skip to content

Commit 319fe6c

Browse files
committed
Add crash recovery via _rewind_to_watermark
On stream start, automatically clean up any data written after the last checkpoint watermark. This handles crash scenarios where data was written but the checkpoint was not saved.

The _rewind_to_watermark method:
1. Gets the last watermark from the state store.
2. Creates invalidation ranges for blocks after the watermark.
3. Calls _handle_reorg to delete uncommitted data.
4. Falls back gracefully if the loader does not support deletion: only the state-store entries after the watermark are invalidated, and a warning is logged that duplicates may occur on resume.

Called automatically at the start of load_stream_continuous, once per table.
1 parent 94c6d76 commit 319fe6c

File tree

2 files changed

+208
-1
lines changed

2 files changed

+208
-1
lines changed

src/amp/loaders/base.py

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,19 @@ def __init__(self, config: Dict[str, Any], label_manager=None) -> None:
7777
else:
7878
self.state_store = NullStreamStateStore()
7979

80+
# Track tables that have undergone crash recovery
81+
self._crash_recovery_done: set[str] = set()
82+
8083
@property
8184
def is_connected(self) -> bool:
8285
"""Check if the loader is connected to the target system."""
8386
return self._is_connected
8487

88+
@property
89+
def loader_type(self) -> str:
90+
"""Get the loader type identifier (e.g., 'postgresql', 'redis')."""
91+
return self.__class__.__name__.replace('Loader', '').lower()
92+
8593
def _parse_config(self, config: Dict[str, Any]) -> TConfig:
8694
"""
8795
Parse configuration into loader-specific format.
@@ -446,11 +454,21 @@ def load_stream_continuous(
446454
if not self._is_connected:
447455
self.connect()
448456

457+
connection_name = kwargs.get('connection_name')
458+
if connection_name is None:
459+
connection_name = self.loader_type
460+
461+
if table_name not in self._crash_recovery_done:
462+
self.logger.info(f'Running crash recovery for table {table_name} (connection: {connection_name})')
463+
self._rewind_to_watermark(table_name, connection_name)
464+
self._crash_recovery_done.add(table_name)
465+
else:
466+
self.logger.info(f'Crash recovery already done for table {table_name}')
467+
449468
rows_loaded = 0
450469
start_time = time.time()
451470
batch_count = 0
452471
reorg_count = 0
453-
connection_name = kwargs.get('connection_name', 'unknown')
454472
worker_id = kwargs.get('worker_id', 0)
455473

456474
try:
@@ -784,6 +802,80 @@ def _handle_reorg(self, invalidation_ranges: List[BlockRange], table_name: str,
784802
'Streaming with reorg detection requires implementing this method.'
785803
)
786804

805+
def _rewind_to_watermark(self, table_name: Optional[str] = None, connection_name: Optional[str] = None) -> None:
806+
"""
807+
Reset state and data to the last checkpointed watermark.
808+
809+
Removes any data written after the last completed watermark,
810+
ensuring resumable streams start from a consistent state.
811+
812+
This handles crash recovery by removing uncommitted data from
813+
incomplete microbatches between watermarks.
814+
815+
Args:
816+
table_name: Table to clean up. If None, crash recovery is skipped (processing all tables is not yet implemented).
817+
connection_name: Connection identifier. If None, uses default.
818+
819+
Example:
820+
def connect(self):
821+
# Connect to database
822+
self._establish_connection()
823+
self._is_connected = True
824+
825+
# Crash recovery - clean up uncommitted data
826+
self._rewind_to_watermark()
827+
"""
828+
if not self.state_enabled:
829+
self.logger.debug('State tracking disabled, skipping crash recovery')
830+
return
831+
832+
if connection_name is None:
833+
connection_name = self.loader_type
834+
835+
tables_to_process = []
836+
if table_name is None:
837+
self.logger.debug('table_name=None not yet implemented, skipping crash recovery')
838+
return
839+
else:
840+
tables_to_process = [table_name]
841+
842+
for table in tables_to_process:
843+
resume_pos = self.state_store.get_resume_position(connection_name, table)
844+
if not resume_pos:
845+
self.logger.debug(f'No watermark found for {table}, skipping crash recovery')
846+
continue
847+
848+
for range_obj in resume_pos.ranges:
849+
from_block = range_obj.end + 1
850+
851+
self.logger.info(
852+
f'Crash recovery: Cleaning up {table} data for {range_obj.network} from block {from_block} onwards'
853+
)
854+
855+
# Create invalidation range for _handle_reorg()
856+
# Note: BlockRange requires 'end' field, but invalidate_from_block() only uses 'start'
857+
# Setting end=from_block is a valid placeholder since the actual range is open-ended
858+
invalidation_ranges = [BlockRange(network=range_obj.network, start=from_block, end=from_block)]
859+
860+
try:
861+
self._handle_reorg(invalidation_ranges, table, connection_name)
862+
self.logger.info(f'Crash recovery completed for {range_obj.network} in {table}')
863+
864+
except NotImplementedError:
865+
invalidated = self.state_store.invalidate_from_block(
866+
connection_name, table, range_obj.network, from_block
867+
)
868+
869+
if invalidated:
870+
self.logger.warning(
871+
f'Crash recovery: Cleared {len(invalidated)} batches from state '
872+
f'for {range_obj.network} but cannot delete data from {table}. '
873+
f'{self.__class__.__name__} does not support data deletion. '
874+
f'Duplicates may occur on resume.'
875+
)
876+
else:
877+
self.logger.debug(f'No uncommitted batches found for {range_obj.network}')
878+
787879
def _add_metadata_columns(self, data: pa.RecordBatch, block_ranges: List[BlockRange]) -> pa.RecordBatch:
788880
"""
789881
Add metadata columns for streaming data with compact batch identification.

tests/unit/test_crash_recovery.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
"""
2+
Unit tests for crash recovery via _rewind_to_watermark() method.
3+
4+
These tests verify the crash recovery logic works correctly in isolation.
5+
"""
6+
7+
from unittest.mock import Mock
8+
9+
import pytest
10+
11+
from src.amp.loaders.base import LoadResult
12+
from src.amp.streaming.types import BlockRange, ResumeWatermark
13+
from tests.fixtures.mock_clients import MockDataLoader
14+
15+
16+
@pytest.fixture
17+
def mock_loader() -> MockDataLoader:
18+
"""Create a mock loader with state store"""
19+
loader = MockDataLoader({'test': 'config'})
20+
loader.connect()
21+
22+
loader.state_store = Mock()
23+
loader.state_enabled = True
24+
25+
return loader
26+
27+
28+
@pytest.mark.unit
29+
class TestCrashRecovery:
30+
"""Test _rewind_to_watermark() crash recovery method"""
31+
32+
def test_rewind_with_no_state(self, mock_loader):
33+
"""Should return early if state_enabled=False"""
34+
mock_loader.state_enabled = False
35+
36+
mock_loader._rewind_to_watermark('test_table', 'test_conn')
37+
38+
mock_loader.state_store.get_resume_position.assert_not_called()
39+
40+
def test_rewind_with_no_watermark(self, mock_loader):
41+
"""Should return early if no watermark exists"""
42+
mock_loader.state_store.get_resume_position = Mock(return_value=None)
43+
44+
mock_loader._rewind_to_watermark('test_table', 'test_conn')
45+
46+
mock_loader.state_store.get_resume_position.assert_called_once_with('test_conn', 'test_table')
47+
48+
def test_rewind_calls_handle_reorg(self, mock_loader):
49+
"""Should call _handle_reorg with correct invalidation ranges"""
50+
watermark = ResumeWatermark(ranges=[BlockRange(network='ethereum', start=1000, end=1010, hash='0xabc')])
51+
mock_loader.state_store.get_resume_position = Mock(return_value=watermark)
52+
mock_loader._handle_reorg = Mock()
53+
54+
mock_loader._rewind_to_watermark('test_table', 'test_conn')
55+
56+
mock_loader._handle_reorg.assert_called_once()
57+
call_args = mock_loader._handle_reorg.call_args
58+
invalidation_ranges = call_args[0][0]
59+
assert len(invalidation_ranges) == 1
60+
assert invalidation_ranges[0].network == 'ethereum'
61+
assert invalidation_ranges[0].start == 1011
62+
assert call_args[0][1] == 'test_table'
63+
assert call_args[0][2] == 'test_conn'
64+
65+
def test_rewind_handles_not_implemented(self, mock_loader):
66+
"""Should gracefully handle loaders without _handle_reorg"""
67+
watermark = ResumeWatermark(ranges=[BlockRange(network='ethereum', start=1000, end=1010, hash='0xabc')])
68+
mock_loader.state_store.get_resume_position = Mock(return_value=watermark)
69+
mock_loader._handle_reorg = Mock(side_effect=NotImplementedError())
70+
mock_loader.state_store.invalidate_from_block = Mock(return_value=[])
71+
72+
mock_loader._rewind_to_watermark('test_table', 'test_conn')
73+
74+
mock_loader.state_store.invalidate_from_block.assert_called_once_with(
75+
'test_conn', 'test_table', 'ethereum', 1011
76+
)
77+
78+
def test_rewind_with_multiple_networks(self, mock_loader):
79+
"""Should process ethereum and polygon separately"""
80+
watermark = ResumeWatermark(
81+
ranges=[
82+
BlockRange(network='ethereum', start=1000, end=1010, hash='0xabc'),
83+
BlockRange(network='polygon', start=2000, end=2010, hash='0xdef'),
84+
]
85+
)
86+
mock_loader.state_store.get_resume_position = Mock(return_value=watermark)
87+
mock_loader._handle_reorg = Mock()
88+
89+
mock_loader._rewind_to_watermark('test_table', 'test_conn')
90+
91+
assert mock_loader._handle_reorg.call_count == 2
92+
93+
first_call = mock_loader._handle_reorg.call_args_list[0]
94+
assert first_call[0][0][0].network == 'ethereum'
95+
assert first_call[0][0][0].start == 1011
96+
97+
second_call = mock_loader._handle_reorg.call_args_list[1]
98+
assert second_call[0][0][0].network == 'polygon'
99+
assert second_call[0][0][0].start == 2011
100+
101+
def test_rewind_with_table_name_none(self, mock_loader):
102+
"""Should return early when table_name=None (not yet implemented)"""
103+
mock_loader._rewind_to_watermark(table_name=None, connection_name='test_conn')
104+
105+
mock_loader.state_store.get_resume_position.assert_not_called()
106+
107+
def test_rewind_uses_default_connection_name(self, mock_loader):
108+
"""Should use default connection name from loader class"""
109+
watermark = ResumeWatermark(ranges=[BlockRange(network='ethereum', start=1000, end=1010, hash='0xabc')])
110+
mock_loader.state_store.get_resume_position = Mock(return_value=watermark)
111+
mock_loader._handle_reorg = Mock()
112+
113+
mock_loader._rewind_to_watermark('test_table', connection_name=None)
114+
115+
mock_loader.state_store.get_resume_position.assert_called_once_with('mockdata', 'test_table')

0 commit comments

Comments
 (0)