
Commit c8c0b0e

fordN and incrypto32 authored
base loader: fix micro batch is_processed marking, add tests (#31)
* base loader: fix micro batch is_processed marking, add tests

* fix: update reorg tests to set ranges_complete=True for proper state tracking

  The recent microbatch processing changes require ranges_complete=True for
  batches to be tracked in the state store. This fixes all reorg handling
  tests by ensuring test batches are properly marked as complete, allowing
  the reorg deletion logic to find and remove the appropriate data.

  - Updated 16 reorg-related tests across 4 loader implementations
  - All test batches now set ranges_complete=True in BatchMetadata
  - Ensures accurate testing of real-world reorg handling behavior

* fix: update unit tests for ranges_complete parameter

  Updated unit tests to account for the ranges_complete parameter, which
  controls when batches are marked as processed and when duplicate checking
  occurs. Tests now correctly pass ranges_complete=True when testing
  duplicate detection and state management.

---------

Co-authored-by: Krishnanand V P <[email protected]>
1 parent 710c4e3 commit c8c0b0e
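For context on the fix: the failure mode is a microbatch that the server splits across several Arrow RecordBatches sharing one BlockRange. If the loader marks the range as processed after the first RecordBatch, the duplicate check then discards the rest of the microbatch. A minimal, self-contained repro of that logic (plain Python stand-ins, not amp's actual classes):

```python
def run(mark_every_record_batch: bool) -> list:
    """Replay one microbatch split into three RecordBatches sharing one ID."""
    processed, sink = set(), []
    deliveries = [([1, 2], False), ([3, 4], False), ([5], True)]  # (rows, ranges_complete)
    for rows, ranges_complete in deliveries:
        batch_id = 'ethereum:100-110@0xabc'  # same microbatch ID every time
        # Old behavior checked/marked on every RecordBatch; the fix gates
        # both the duplicate check and the marking on ranges_complete.
        gate = True if mark_every_record_batch else ranges_complete
        if gate and batch_id in processed:
            continue  # skipped as a "duplicate"
        sink.extend(rows)
        if gate:
            processed.add(batch_id)
    return sink

assert run(mark_every_record_batch=True) == [1, 2]           # bug: rows 3-5 dropped
assert run(mark_every_record_batch=False) == [1, 2, 3, 4, 5]  # fix: full microbatch
```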

8 files changed (+439, -58 lines)


src/amp/loaders/base.py

Lines changed: 24 additions & 8 deletions
```diff
@@ -484,6 +484,7 @@ def load_stream_continuous(
                     table_name,
                     connection_name,
                     response.metadata.ranges,
+                    ranges_complete=response.metadata.ranges_complete,
                 )
             else:
                 # Non-transactional loading (separate check, load, mark)
@@ -494,6 +495,7 @@ def load_stream_continuous(
                     table_name,
                     connection_name,
                     response.metadata.ranges,
+                    ranges_complete=response.metadata.ranges_complete,
                     **filtered_kwargs,
                 )
 
@@ -611,6 +613,7 @@ def _process_batch_transactional(
         table_name: str,
         connection_name: str,
         ranges: List[BlockRange],
+        ranges_complete: bool = False,
     ) -> LoadResult:
         """
         Process a data batch using transactional exactly-once semantics.
@@ -622,6 +625,7 @@ def _process_batch_transactional(
             table_name: Target table name
             connection_name: Connection identifier
             ranges: Block ranges for this batch
+            ranges_complete: True when this RecordBatch completes a microbatch (streaming only)
 
         Returns:
             LoadResult with operation outcome
@@ -630,13 +634,17 @@ def _process_batch_transactional(
         try:
             # Delegate to loader-specific transactional implementation
             # Loaders that support transactions implement load_batch_transactional()
-            rows_loaded_batch = self.load_batch_transactional(batch_data, table_name, connection_name, ranges)
+            rows_loaded_batch = self.load_batch_transactional(
+                batch_data, table_name, connection_name, ranges, ranges_complete
+            )
             duration = time.time() - start_time
 
-            # Mark batches as processed in state store after successful transaction
-            if ranges:
+            # Mark batches as processed ONLY when the microbatch is complete;
+            # multiple RecordBatches can share the same microbatch ID
+            if ranges and ranges_complete:
                 batch_ids = [BatchIdentifier.from_block_range(br) for br in ranges]
                 self.state_store.mark_processed(connection_name, table_name, batch_ids)
+                self.logger.debug(f'Marked microbatch as processed: {len(batch_ids)} batch IDs')
 
             return LoadResult(
                 rows_loaded=rows_loaded_batch,
@@ -648,6 +656,7 @@ def _process_batch_transactional(
                 metadata={
                     'operation': 'transactional_load' if rows_loaded_batch > 0 else 'skip_duplicate',
                     'ranges': [r.to_dict() for r in ranges],
+                    'ranges_complete': ranges_complete,
                 },
             )
 
@@ -670,6 +679,7 @@ def _process_batch_non_transactional(
         table_name: str,
         connection_name: str,
         ranges: Optional[List[BlockRange]],
+        ranges_complete: bool = False,
         **kwargs,
     ) -> Optional[LoadResult]:
         """
@@ -682,21 +692,25 @@ def _process_batch_non_transactional(
             table_name: Target table name
             connection_name: Connection identifier
             ranges: Block ranges for this batch (if available)
+            ranges_complete: True when this RecordBatch completes a microbatch (streaming only)
             **kwargs: Additional options passed to load_batch
 
         Returns:
             LoadResult, or None if batch was skipped as duplicate
         """
         # Check if batch already processed (idempotency / exactly-once)
-        if ranges and self.state_enabled:
+        # For streaming: only check when ranges_complete=True (end of microbatch)
+        # Multiple RecordBatches can share the same microbatch ID, so we must wait
+        # until the entire microbatch is delivered before checking/marking as processed
+        if ranges and self.state_enabled and ranges_complete:
             try:
                 batch_ids = [BatchIdentifier.from_block_range(br) for br in ranges]
                 is_duplicate = self.state_store.is_processed(connection_name, table_name, batch_ids)
 
                 if is_duplicate:
                     # Skip this batch - already processed
                     self.logger.info(
-                        f'Skipping duplicate batch: {len(ranges)} ranges already processed for {table_name}'
+                        f'Skipping duplicate microbatch: {len(ranges)} ranges already processed for {table_name}'
                     )
                     return LoadResult(
                         rows_loaded=0,
@@ -711,14 +725,16 @@ def _process_batch_non_transactional(
                 # BlockRange missing hash - log and continue without idempotency check
                 self.logger.warning(f'Cannot check for duplicates: {e}. Processing batch anyway.')
 
-        # Load batch
+        # Load batch (always load, even if part of larger microbatch)
         result = self.load_batch(batch_data, table_name, **kwargs)
 
-        if result.success and ranges and self.state_enabled:
-            # Mark batch as processed (for exactly-once semantics)
+        # Mark batch as processed ONLY when microbatch is complete
+        # This ensures we don't skip subsequent RecordBatches within the same microbatch
+        if result.success and ranges and self.state_enabled and ranges_complete:
             try:
                 batch_ids = [BatchIdentifier.from_block_range(br) for br in ranges]
                 self.state_store.mark_processed(connection_name, table_name, batch_ids)
+                self.logger.debug(f'Marked microbatch as processed: {len(batch_ids)} batch IDs')
             except Exception as e:
                 self.logger.error(f'Failed to mark batches as processed: {e}')
                 # Continue anyway - state store provides resume capability
```
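The non-transactional path above keeps a strict check → load → mark order, with the check and the mark both gated on ranges_complete. A runnable toy of that order and its replay behavior (the state store and string batch IDs are simplified stand-ins for amp's BatchIdentifier and state-store machinery, not the real API):

```python
import logging

logging.basicConfig(level=logging.INFO, format='%(message)s')
log = logging.getLogger('sketch')

class ToyStateStore:
    """Stand-in for the loader's state store, keyed by (connection, table, id)."""
    def __init__(self):
        self._seen = set()

    def is_processed(self, conn, table, batch_ids):
        return all((conn, table, b) in self._seen for b in batch_ids)

    def mark_processed(self, conn, table, batch_ids):
        self._seen.update((conn, table, b) for b in batch_ids)

def process_non_transactional(store, rows, table, conn, batch_ids, ranges_complete, sink):
    """Mirrors the check -> load -> mark order above, gated on ranges_complete."""
    if batch_ids and ranges_complete and store.is_processed(conn, table, batch_ids):
        log.info(f'Skipping duplicate microbatch: {len(batch_ids)} ranges for {table}')
        return 0
    sink.extend(rows)  # the load_batch(...) step: always runs mid-microbatch
    if batch_ids and ranges_complete:
        store.mark_processed(conn, table, batch_ids)
        log.info(f'Marked microbatch as processed: {len(batch_ids)} batch IDs')
    return len(rows)

store, sink = ToyStateStore(), []
ids = ['ethereum:100-110@0xabc']
# First delivery: two RecordBatches; only the second completes the microbatch.
process_non_transactional(store, [1, 2], 't', 'c', ids, False, sink)
process_non_transactional(store, [3], 't', 'c', ids, True, sink)
# Replay after a restart: the completed microbatch is detected and skipped.
assert process_non_transactional(store, [1, 2, 3], 't', 'c', ids, True, sink) == 0
assert sink == [1, 2, 3]
```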

src/amp/loaders/implementations/postgresql_loader.py

Lines changed: 10 additions & 5 deletions
```diff
@@ -119,6 +119,7 @@ def load_batch_transactional(
         table_name: str,
         connection_name: str,
         ranges: List[BlockRange],
+        ranges_complete: bool = False,
     ) -> int:
         """
         Load a batch with transactional exactly-once semantics using in-memory state.
@@ -135,6 +136,7 @@ def load_batch_transactional(
             table_name: Target table name
             connection_name: Connection identifier for tracking
             ranges: Block ranges covered by this batch
+            ranges_complete: True when this RecordBatch completes a microbatch (streaming only)
 
         Returns:
             Number of rows loaded (0 if duplicate)
@@ -149,24 +151,27 @@ def load_batch_transactional(
             self.logger.warning(f'Cannot create batch identifiers: {e}. Loading without duplicate check.')
             batch_ids = []
 
-        # Check if already processed (using in-memory state)
-        if batch_ids and self.state_store.is_processed(connection_name, table_name, batch_ids):
+        # Check if already processed ONLY when microbatch is complete
+        # Multiple RecordBatches can share the same microbatch ID (BlockRange)
+        if batch_ids and ranges_complete and self.state_store.is_processed(connection_name, table_name, batch_ids):
             self.logger.info(
                 f'Batch already processed (ranges: {[f"{r.network}:{r.start}-{r.end}" for r in ranges]}), '
                 f'skipping (state check)'
             )
             return 0
 
-        # Load data
+        # Load data (always load, even if part of larger microbatch)
         conn = self.pool.getconn()
         try:
             with conn.cursor() as cur:
                 self._copy_arrow_data(cur, batch, table_name)
                 conn.commit()
 
-            # Mark as processed after successful load
-            if batch_ids:
+            # Mark as processed ONLY when microbatch is complete
+            # This ensures we don't skip subsequent RecordBatches within the same microbatch
+            if batch_ids and ranges_complete:
                 self.state_store.mark_processed(connection_name, table_name, batch_ids)
+                self.logger.debug(f'Marked microbatch as processed: {len(batch_ids)} batch IDs')
 
             self.logger.debug(
                 f'Batch load committed: {batch.num_rows} rows, '
```
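For concrete loaders, the contract change is just the new trailing parameter on load_batch_transactional. A toy analogue of the PostgreSQL path's skip/load/mark ordering, with sqlite3 standing in for the psycopg connection pool and COPY, and a plain set for the in-memory state store; this sketches the ordering only, not the project's implementation:

```python
import sqlite3

class ToyTransactionalLoader:
    """Toy analogue of the PostgreSQL loader's transactional path."""

    def __init__(self):
        self.db = sqlite3.connect(':memory:')  # stand-in for the psycopg pool
        self.db.execute('CREATE TABLE blocks (n INTEGER)')
        self.processed = set()                 # stand-in for the state store

    def load_batch_transactional(self, rows, table, connection_name, ranges,
                                 ranges_complete=False):
        batch_ids = [f'{net}:{s}-{e}@{h}' for (net, s, e, h) in ranges]
        # Skip only when the *complete* microbatch was already processed.
        if batch_ids and ranges_complete and all(b in self.processed for b in batch_ids):
            return 0
        # Load data (always load, even if part of a larger microbatch).
        self.db.executemany(f'INSERT INTO {table} (n) VALUES (?)',
                            [(r,) for r in rows])
        self.db.commit()
        # Mark as processed only at the microbatch boundary.
        if batch_ids and ranges_complete:
            self.processed.update(batch_ids)
        return len(rows)

loader = ToyTransactionalLoader()
rng = [('ethereum', 100, 110, '0xabc')]
loader.load_batch_transactional([1, 2], 'blocks', 'c', rng, ranges_complete=False)
loader.load_batch_transactional([3], 'blocks', 'c', rng, ranges_complete=True)
# Replaying the completed microbatch is now a no-op.
assert loader.load_batch_transactional([1, 2, 3], 'blocks', 'c', rng, True) == 0
assert loader.db.execute('SELECT COUNT(*) FROM blocks').fetchone()[0] == 3
```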

tests/integration/test_deltalake_loader.py

Lines changed: 60 additions & 15 deletions
```diff
@@ -586,15 +586,24 @@ def test_handle_reorg_single_network(self, delta_temp_config):
         # Create response batches with hashes
         response1 = ResponseBatch.data_batch(
             data=batch1,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
         response2 = ResponseBatch.data_batch(
             data=batch2,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xdef')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xdef')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
         response3 = ResponseBatch.data_batch(
             data=batch3,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=200, end=210, hash='0x123')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=200, end=210, hash='0x123')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
 
         # Load via streaming API
@@ -637,19 +646,31 @@ def test_handle_reorg_multi_network(self, delta_temp_config):
         # Create response batches with network-specific ranges
         response1 = ResponseBatch.data_batch(
             data=batch1,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xaaa')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xaaa')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
         response2 = ResponseBatch.data_batch(
             data=batch2,
-            metadata=BatchMetadata(ranges=[BlockRange(network='polygon', start=100, end=110, hash='0xbbb')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='polygon', start=100, end=110, hash='0xbbb')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
         response3 = ResponseBatch.data_batch(
             data=batch3,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xccc')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xccc')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
         response4 = ResponseBatch.data_batch(
             data=batch4,
-            metadata=BatchMetadata(ranges=[BlockRange(network='polygon', start=150, end=160, hash='0xddd')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='polygon', start=150, end=160, hash='0xddd')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
 
         # Load via streaming API
@@ -689,15 +710,24 @@ def test_handle_reorg_overlapping_ranges(self, delta_temp_config):
         # Batch 3: 170-190 (after reorg, but should be deleted as 170 >= 150)
         response1 = ResponseBatch.data_batch(
             data=batch1,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=90, end=110, hash='0xaaa')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=90, end=110, hash='0xaaa')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
         response2 = ResponseBatch.data_batch(
             data=batch2,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=140, end=160, hash='0xbbb')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=140, end=160, hash='0xbbb')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
         response3 = ResponseBatch.data_batch(
             data=batch3,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=170, end=190, hash='0xccc')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=170, end=190, hash='0xccc')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
 
         # Load via streaming API
@@ -733,15 +763,24 @@ def test_handle_reorg_version_history(self, delta_temp_config):
 
         response1 = ResponseBatch.data_batch(
             data=batch1,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=0, end=10, hash='0xaaa')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=0, end=10, hash='0xaaa')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
         response2 = ResponseBatch.data_batch(
             data=batch2,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=50, end=60, hash='0xbbb')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=50, end=60, hash='0xbbb')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
         response3 = ResponseBatch.data_batch(
             data=batch3,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xccc')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xccc')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
 
         # Load via streaming API
@@ -792,12 +831,18 @@ def test_streaming_with_reorg(self, delta_temp_config):
         # Create response batches using factory methods (with hashes for proper state management)
         response1 = ResponseBatch.data_batch(
             data=data1,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
 
         response2 = ResponseBatch.data_batch(
             data=data2,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xdef456')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xdef456')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
 
         # Simulate reorg event using factory method
```
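The same edit pattern repeats across all 16 updated tests: a batch only becomes visible to duplicate detection and reorg deletion once it carries ranges_complete=True. Helpers like the following can cut that repetition in new tests; the import path is an assumption (the diff does not show where BatchMetadata, BlockRange, and ResponseBatch live):

```python
# Hypothetical import path; adjust to wherever these classes live in amp.
from amp.types import BatchMetadata, BlockRange, ResponseBatch

def complete_batch(data, network, start, end, block_hash):
    """One self-contained microbatch: tracked in the state store."""
    return ResponseBatch.data_batch(
        data=data,
        metadata=BatchMetadata(
            ranges=[BlockRange(network=network, start=start, end=end, hash=block_hash)],
            ranges_complete=True,  # completes the microbatch -> tracked in state store
        ),
    )

def split_microbatch(pieces, network, start, end, block_hash):
    """One microbatch delivered as several RecordBatches sharing one BlockRange."""
    rng = [BlockRange(network=network, start=start, end=end, hash=block_hash)]
    return [
        ResponseBatch.data_batch(
            data=piece,
            # Only the final piece closes the microbatch; earlier pieces must
            # not trigger duplicate checks or state-store marking.
            metadata=BatchMetadata(ranges=rng, ranges_complete=(i == len(pieces) - 1)),
        )
        for i, piece in enumerate(pieces)
    ]
```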
