Commit 86749a0

refactor: Update loaders for new base class interface
- PostgreSQL: Add reorg support with DELETE/UPDATE, metadata columns
- Redis: Add streaming metadata and batch ID support
- DeltaLake: Support new metadata columns
- Iceberg: Update for base class changes
- LMDB: Add metadata column support

All loaders now support:

- State-backed resume and deduplication
- Label joining via base class
- Resilience features (retry, backpressure)
- Reorg-aware streaming with metadata tracking
1 parent 9fdb633 commit 86749a0

File tree

5 files changed: +269 −187 lines changed

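The diffs below all target the same new DataLoader base-class contract: constructors take an optional label_manager, and _handle_reorg now receives connection_name so the loader can ask the state store which batches a reorg invalidates (state_store.invalidate_from_block returns batch IDs whose unique_id values are matched against the _amp_batch_id column). Below is a minimal, self-contained sketch of that shared first step; the BlockRange/BatchID/StubStateStore stand-ins and the sample values are illustrative assumptions, only the method names and call shapes come from the diffs.

from dataclasses import dataclass
from typing import List, Set

@dataclass
class BlockRange:          # stand-in for the real BlockRange used in the diffs
    network: str
    start: int
    end: int

@dataclass(frozen=True)
class BatchID:             # stand-in: the real objects expose a unique_id attribute
    unique_id: str

class StubStateStore:
    """Illustrative stand-in for the loaders' state_store."""
    def __init__(self, batches):
        self._batches = batches  # list of (network, start_block, BatchID)

    def invalidate_from_block(self, connection_name, table_name, network, start) -> List[BatchID]:
        # Return the batches at or after the reorg start block for this network
        return [b for (net, blk, b) in self._batches if net == network and blk >= start]

def affected_batch_ids(state_store, connection_name, table_name, ranges: List[BlockRange]) -> Set[str]:
    """Common first step of _handle_reorg in each loader below."""
    affected = []
    for r in ranges:
        affected.extend(state_store.invalidate_from_block(connection_name, table_name, r.network, r.start))
    return {b.unique_id for b in affected}

# Usage sketch with made-up values
store = StubStateStore([('mainnet', 100, BatchID('b-100')), ('mainnet', 205, BatchID('b-205'))])
print(affected_batch_ids(store, 'conn-1', 'events', [BlockRange('mainnet', 200, 210)]))  # {'b-205'}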

src/amp/loaders/implementations/deltalake_loader.py

Lines changed: 32 additions & 44 deletions
@@ -80,11 +80,11 @@ class DeltaLakeLoader(DataLoader[DeltaStorageConfig]):
     REQUIRES_SCHEMA_MATCH = False
     SUPPORTS_TRANSACTIONS = True
 
-    def __init__(self, config: Dict[str, Any]):
+    def __init__(self, config: Dict[str, Any], label_manager=None):
         if not DELTALAKE_AVAILABLE:
             raise ImportError("Delta Lake support requires 'deltalake' package. Install with: pip install deltalake")
 
-        super().__init__(config)
+        super().__init__(config, label_manager=label_manager)
 
         # Performance settings
         self.batch_size = config.get('batch_size', 10000)
@@ -644,17 +644,16 @@ def query_table(self, columns: Optional[List[str]] = None, limit: Optional[int]
             self.logger.error(f'Query failed: {e}')
             raise
 
-    def _handle_reorg(self, invalidation_ranges: List[BlockRange], table_name: str) -> None:
+    def _handle_reorg(self, invalidation_ranges: List[BlockRange], table_name: str, connection_name: str) -> None:
         """
         Handle blockchain reorganization by deleting affected rows from Delta Lake.
 
-        Delta Lake's versioning and transaction capabilities make this operation
-        particularly powerful - we can precisely delete affected data and even
-        roll back if needed using time travel features.
+        Uses the _amp_batch_id column for fast, indexed deletion of affected batches.
 
         Args:
             invalidation_ranges: List of block ranges to invalidate (reorg points)
             table_name: The table containing the data to invalidate (not used but kept for API consistency)
+            connection_name: The connection name (for state invalidation)
         """
         if not invalidation_ranges:
             return
@@ -665,62 +664,51 @@ def _handle_reorg(self, invalidation_ranges: List[BlockRange], table_name: str)
             self.logger.warning('No Delta table connected, skipping reorg handling')
             return
 
+        # Get affected batch IDs from state store
+        all_affected_batch_ids = []
+        for range_obj in invalidation_ranges:
+            affected_batch_ids = self.state_store.invalidate_from_block(
+                connection_name, table_name, range_obj.network, range_obj.start
+            )
+            all_affected_batch_ids.extend(affected_batch_ids)
+
+        if not all_affected_batch_ids:
+            self.logger.info('No batches found to invalidate')
+            return
+
         # Load the current table data
         current_table = self._delta_table.to_pyarrow_table()
 
-        # Check if the table has metadata column
-        if '_meta_block_ranges' not in current_table.schema.names:
-            self.logger.warning("Delta table doesn't have '_meta_block_ranges' column, skipping reorg handling")
+        # Check if the table has batch_id column
+        if '_amp_batch_id' not in current_table.schema.names:
+            self.logger.warning("Delta table doesn't have '_amp_batch_id' column, skipping reorg handling")
             return
 
         # Build a mask to identify rows to keep
+        batch_id_column = current_table['_amp_batch_id']
         keep_mask = pa.array([True] * current_table.num_rows)
 
-        # Process each row to check if it should be invalidated
-        meta_column = current_table['_meta_block_ranges']
-
+        # Mark rows for deletion if their batch_id matches any affected batch
+        batch_id_set = {bid.unique_id for bid in all_affected_batch_ids}
         for i in range(current_table.num_rows):
-            meta_json = meta_column[i].as_py()
-
-            if meta_json:
-                try:
-                    ranges_data = json.loads(meta_json)
-
-                    # Ensure ranges_data is a list
-                    if not isinstance(ranges_data, list):
-                        continue
-
-                    # Check each invalidation range
-                    for range_obj in invalidation_ranges:
-                        network = range_obj.network
-                        reorg_start = range_obj.start
-
-                        # Check if any range for this network should be invalidated
-                        for range_info in ranges_data:
-                            if (
-                                isinstance(range_info, dict)
-                                and range_info.get('network') == network
-                                and range_info.get('end', 0) >= reorg_start
-                            ):
-                                # Mark this row for deletion
-                                # Create a mask for this specific row
-                                row_mask = pa.array([j == i for j in range(current_table.num_rows)])
-                                keep_mask = pa.compute.and_(keep_mask, pa.compute.invert(row_mask))
-                                break
-
-                except (json.JSONDecodeError, KeyError):
-                    pass
+            batch_id_str = batch_id_column[i].as_py()
+            if batch_id_str:
+                # Check if any of the batch IDs in this row match affected batches
+                for batch_id in batch_id_str.split('|'):
+                    if batch_id in batch_id_set:
+                        row_mask = pa.array([j == i for j in range(current_table.num_rows)])
+                        keep_mask = pa.compute.and_(keep_mask, pa.compute.invert(row_mask))
+                        break
 
         # Filter the table to keep only valid rows
         filtered_table = current_table.filter(keep_mask)
         deleted_count = current_table.num_rows - filtered_table.num_rows
 
         if deleted_count > 0:
             # Overwrite the table with filtered data
-            # This creates a new version in Delta Lake, preserving history
             self.logger.info(
                 f'Executing blockchain reorg deletion for {len(invalidation_ranges)} networks '
-                f'in Delta Lake table. Deleting {deleted_count} rows.'
+                f'in Delta Lake table. Deleting {deleted_count} rows affected by {len(all_affected_batch_ids)} batches.'
             )
 
             # Use overwrite mode to replace table contents
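
One note on the new loop above: it still builds a full-length row_mask per matched row and ANDs it into keep_mask, which is quadratic in the row count. An equivalent single-pass sketch, assuming the same pipe-joined _amp_batch_id values (not the committed code):

import pyarrow as pa

def build_keep_mask(batch_id_column: pa.ChunkedArray, affected_ids: set) -> pa.Array:
    """True for rows to keep, False for rows whose batch ID was invalidated."""
    keep = []
    for value in batch_id_column.to_pylist():
        invalidated = bool(value) and any(b in affected_ids for b in value.split('|'))
        keep.append(not invalidated)
    return pa.array(keep, type=pa.bool_())

# keep_mask = build_keep_mask(current_table['_amp_batch_id'], batch_id_set)
# filtered_table = current_table.filter(keep_mask)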

src/amp/loaders/implementations/iceberg_loader.py

Lines changed: 6 additions & 5 deletions
@@ -76,13 +76,13 @@ class IcebergLoader(DataLoader[IcebergStorageConfig]):
     REQUIRES_SCHEMA_MATCH = False
     SUPPORTS_TRANSACTIONS = True
 
-    def __init__(self, config: Dict[str, Any]):
+    def __init__(self, config: Dict[str, Any], label_manager=None):
         if not ICEBERG_AVAILABLE:
             raise ImportError(
                 "Apache Iceberg support requires 'pyiceberg' package. Install with: pip install pyiceberg"
             )
 
-        super().__init__(config)
+        super().__init__(config, label_manager=label_manager)
 
         self._catalog: Optional[IcebergCatalog] = None
         self._current_table: Optional[IcebergTable] = None
@@ -283,7 +283,7 @@ def _validate_schema_compatibility(self, iceberg_table: IcebergTable, arrow_sche
             # Evolution mode: evolve schema to accommodate new fields
             self._evolve_schema_if_needed(iceberg_table, iceberg_schema, arrow_schema)
 
-    def _validate_schema_strict(self, iceberg_schema: IcebergSchema, arrow_schema: pa.Schema) -> None:
+    def _validate_schema_strict(self, iceberg_schema: 'IcebergSchema', arrow_schema: pa.Schema) -> None:
         """Validate schema compatibility in strict mode (no evolution)"""
         iceberg_field_names = {field.name for field in iceberg_schema.fields}
         arrow_field_names = {field.name for field in arrow_schema}
@@ -304,7 +304,7 @@ def _validate_schema_strict(self, iceberg_schema: IcebergSchema, arrow_schema: p
         self.logger.debug('Schema validation passed in strict mode')
 
     def _evolve_schema_if_needed(
-        self, iceberg_table: IcebergTable, iceberg_schema: IcebergSchema, arrow_schema: pa.Schema
+        self, iceberg_table: 'IcebergTable', iceberg_schema: 'IcebergSchema', arrow_schema: pa.Schema
     ) -> None:
         """Evolve the Iceberg table schema to accommodate new Arrow schema fields"""
         try:
@@ -506,7 +506,7 @@ def get_table_info(self, table_name: str) -> Dict[str, Any]:
             self.logger.error(f'Failed to get table info for {table_name}: {e}')
             return {'exists': False, 'error': str(e), 'table_name': table_name}
 
-    def _handle_reorg(self, invalidation_ranges: List[BlockRange], table_name: str) -> None:
+    def _handle_reorg(self, invalidation_ranges: List[BlockRange], table_name: str, connection_name: str) -> None:
         """
         Handle blockchain reorganization by deleting affected rows from Iceberg table.
 
@@ -518,6 +518,7 @@ def _handle_reorg(self, invalidation_ranges: List[BlockRange], table_name: str)
         Args:
             invalidation_ranges: List of block ranges to invalidate (reorg points)
             table_name: The table containing the data to invalidate
+            connection_name: The connection name (for state invalidation)
         """
         if not invalidation_ranges:
             return
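
The switch to quoted 'IcebergTable'/'IcebergSchema' annotations matters because pyiceberg is an optional dependency: string annotations keep the module importable when those names were never bound. A hedged illustration of that pattern (the pyiceberg import paths here are assumptions, not taken from this commit):

try:
    # Optional dependency: these aliases only exist if pyiceberg is installed
    from pyiceberg.catalog import Catalog as IcebergCatalog
    from pyiceberg.schema import Schema as IcebergSchema
    from pyiceberg.table import Table as IcebergTable
    ICEBERG_AVAILABLE = True
except ImportError:
    ICEBERG_AVAILABLE = False

def field_names(iceberg_schema: 'IcebergSchema') -> set:
    # The quoted annotation is not evaluated at definition time, so this
    # function can be defined even when the import above failed.
    return {field.name for field in iceberg_schema.fields}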

src/amp/loaders/implementations/lmdb_loader.py

Lines changed: 34 additions & 42 deletions
@@ -64,8 +64,8 @@ class LMDBLoader(DataLoader[LMDBConfig]):
     REQUIRES_SCHEMA_MATCH = False
     SUPPORTS_TRANSACTIONS = True
 
-    def __init__(self, config: Dict[str, Any]):
-        super().__init__(config)
+    def __init__(self, config: Dict[str, Any], label_manager=None):
+        super().__init__(config, label_manager=label_manager)
 
         self.env: Optional[lmdb.Environment] = None
         self.dbs: Dict[str, Any] = {}  # Cache opened databases
@@ -350,75 +350,67 @@ def get_table_info(self, table_name: str) -> Optional[Dict[str, Any]]:
             self.logger.error(f'Failed to get table info: {e}')
             return None
 
-    def _handle_reorg(self, invalidation_ranges: List[BlockRange], table_name: str) -> None:
+    def _handle_reorg(self, invalidation_ranges: List[BlockRange], table_name: str, connection_name: str) -> None:
         """
         Handle blockchain reorganization by deleting affected entries from LMDB.
 
-        LMDB's key-value architecture requires iterating through entries to find
-        and delete affected data based on the metadata stored in each value.
+        Uses the _amp_batch_id column for fast deletion of affected batches.
 
         Args:
             invalidation_ranges: List of block ranges to invalidate (reorg points)
             table_name: The table containing the data to invalidate
+            connection_name: The connection name (for state invalidation)
         """
         if not invalidation_ranges:
            return
 
         try:
+            # Get affected batch IDs from state store
+            all_affected_batch_ids = []
+            for range_obj in invalidation_ranges:
+                affected_batch_ids = self.state_store.invalidate_from_block(
+                    connection_name, table_name, range_obj.network, range_obj.start
+                )
+                all_affected_batch_ids.extend(affected_batch_ids)
+
+            if not all_affected_batch_ids:
+                self.logger.info('No batches found to invalidate')
+                return
+
+            batch_id_set = {bid.unique_id for bid in all_affected_batch_ids}
+
             db = self._get_or_create_db(self.config.database_name)
             deleted_count = 0
 
             with self.env.begin(write=True, db=db) as txn:
                 cursor = txn.cursor()
                 keys_to_delete = []
 
-                # First pass: identify keys to delete
+                # First pass: identify keys to delete based on batch_id
                 if cursor.first():
                     while True:
                         key = cursor.key()
                         value = cursor.value()
 
-                        # Deserialize the Arrow batch to check metadata
+                        # Deserialize the Arrow batch to check batch_id
                        try:
                             # Read the serialized Arrow batch
                             reader = pa.ipc.open_stream(value)
                             batch = reader.read_next_batch()
 
-                            # Check if this batch has metadata column
-                            if '_meta_block_ranges' in batch.schema.names:
-                                # Get the metadata (should be a single row)
-                                meta_idx = batch.schema.get_field_index('_meta_block_ranges')
-                                meta_json = batch.column(meta_idx)[0].as_py()
-
-                                if meta_json:
-                                    try:
-                                        ranges_data = json.loads(meta_json)
-
-                                        # Ensure ranges_data is a list
-                                        if not isinstance(ranges_data, list):
-                                            continue
-
-                                        # Check each invalidation range
-                                        for range_obj in invalidation_ranges:
-                                            network = range_obj.network
-                                            reorg_start = range_obj.start
-
-                                            # Check if any range for this network should be invalidated
-                                            for range_info in ranges_data:
-                                                if (
-                                                    isinstance(range_info, dict)
-                                                    and range_info.get('network') == network
-                                                    and range_info.get('end', 0) >= reorg_start
-                                                ):
-                                                    keys_to_delete.append(key)
-                                                    deleted_count += 1
-                                                    break
-
-                                            if key in keys_to_delete:
-                                                break
-
-                                    except (json.JSONDecodeError, KeyError):
-                                        pass
+                            # Check if this batch has batch_id column
+                            if '_amp_batch_id' in batch.schema.names:
+                                # Get the batch_id (should be a single row)
+                                batch_id_idx = batch.schema.get_field_index('_amp_batch_id')
+                                batch_id_str = batch.column(batch_id_idx)[0].as_py()
+
+                                if batch_id_str:
+                                    # Check if any of the batch IDs match affected batches
+                                    for batch_id in batch_id_str.split('|'):
+                                        if batch_id in batch_id_set:
+                                            keys_to_delete.append(key)
+                                            deleted_count += 1
+                                            break
 
                         except Exception as e:
                             self.logger.debug(f'Failed to deserialize entry: {e}')
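
Each LMDB value is a serialized Arrow IPC stream whose single-row _amp_batch_id column carries pipe-joined batch IDs; that is what the loop above inspects. A standalone sketch of the same inspection under those assumptions (extract_batch_ids is a hypothetical helper, not part of the commit):

import pyarrow as pa

def extract_batch_ids(serialized_value: bytes) -> list:
    """Return the batch IDs recorded in one stored value, or [] if the column is absent."""
    reader = pa.ipc.open_stream(serialized_value)
    batch = reader.read_next_batch()
    if '_amp_batch_id' not in batch.schema.names:
        return []
    raw = batch.column(batch.schema.get_field_index('_amp_batch_id'))[0].as_py()
    return raw.split('|') if raw else []

# Round-trip one single-row batch through the IPC stream format to exercise it
batch = pa.RecordBatch.from_pydict({'value': [42], '_amp_batch_id': ['b-100|b-205']})
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, batch.schema) as writer:
    writer.write_batch(batch)
print(extract_batch_ids(sink.getvalue().to_pybytes()))  # ['b-100', 'b-205']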
