
Commit 17718e2

snowflake loader: Add reorg-aware streaming support

1 parent e6ba334

2 files changed: +283 -3 lines

src/amp/loaders/implementations/snowflake_loader.py
Lines changed: 84 additions & 3 deletions

@@ -1,13 +1,14 @@
 import io
 import time
 from dataclasses import dataclass
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional

 import pyarrow as pa
 import pyarrow.csv as pa_csv
 import snowflake.connector
 from snowflake.connector import DictCursor, SnowflakeConnection

+from ...streaming.types import BlockRange
 from ..base import DataLoader, LoadMode

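
The only new dependency is the BlockRange type from the streaming layer. Its actual definition lives in src/amp/streaming/types.py; as a reading aid, a minimal sketch inferred from the call sites in this commit (e.g. BlockRange(network='ethereum', start=150, end=200)) looks like:

    # Sketch only: inferred from usage in this commit, not the actual definition.
    from dataclasses import dataclass

    @dataclass(frozen=True)
    class BlockRange:
        network: str  # chain identifier, e.g. 'ethereum' or 'polygon'
        start: int    # first block number covered by the range
        end: int      # last block number covered by the range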
@@ -390,7 +391,7 @@ def get_table_info(self, table_name: str) -> Optional[Dict[str, Any]]:
             # Get table metadata
             self.cursor.execute(
                 """
-                SELECT
+                SELECT
                     TABLE_NAME,
                     TABLE_SCHEMA,
                     TABLE_CATALOG,
@@ -412,7 +413,7 @@ def get_table_info(self, table_name: str) -> Optional[Dict[str, Any]]:
             # Get column information
             self.cursor.execute(
                 """
-                SELECT
+                SELECT
                     COLUMN_NAME,
                     DATA_TYPE,
                     IS_NULLABLE,
@@ -456,3 +457,83 @@ def get_table_info(self, table_name: str) -> Optional[Dict[str, Any]]:
         except Exception as e:
             self.logger.error(f"Failed to get table info for '{table_name}': {str(e)}")
             return None
+
+    def _handle_reorg(self, invalidation_ranges: List[BlockRange], table_name: str) -> None:
+        """
+        Handle a blockchain reorganization by deleting affected rows from Snowflake.
+
+        Snowflake's SQL capabilities allow for efficient deletion, using JSON functions
+        to parse the _meta_block_ranges column and identify affected rows.
+
+        Args:
+            invalidation_ranges: List of block ranges to invalidate (reorg points)
+            table_name: The table containing the data to invalidate
+        """
+        if not invalidation_ranges:
+            return
+
+        try:
+            # First check if the table has the metadata column
+            self.cursor.execute(
+                """
+                SELECT COUNT(*) as count
+                FROM INFORMATION_SCHEMA.COLUMNS
+                WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? AND COLUMN_NAME = '_META_BLOCK_RANGES'
+                """,
+                (self.config.schema, table_name.upper()),
+            )
+
+            result = self.cursor.fetchone()
+            if not result or result['COUNT'] == 0:
+                self.logger.warning(
+                    f"Table '{table_name}' doesn't have a '_meta_block_ranges' column, skipping reorg handling"
+                )
+                return
+
+            # Build a DELETE statement with one condition per invalidation range
+            # Snowflake's PARSE_JSON and FLATTEN functions help work with the JSON data
+            delete_conditions = []
+
+            for range_obj in invalidation_ranges:
+                network = range_obj.network
+                reorg_start = range_obj.start
+
+                # Condition for this network's reorg: delete rows where any range
+                # in the JSON array for this network has end >= reorg_start
+                condition = f"""
+                EXISTS (
+                    SELECT 1
+                    FROM TABLE(FLATTEN(input => PARSE_JSON("_META_BLOCK_RANGES"))) f
+                    WHERE f.value:network::STRING = '{network}'
+                    AND f.value:end::NUMBER >= {reorg_start}
+                )
+                """
+                delete_conditions.append(condition)
+
+            # Combine conditions with OR
+            if delete_conditions:
+                where_clause = ' OR '.join(f'({cond})' for cond in delete_conditions)
+
+                # Execute deletion
+                delete_sql = f'DELETE FROM {table_name} WHERE {where_clause}'
+
+                self.logger.info(
+                    f'Executing blockchain reorg deletion for {len(invalidation_ranges)} networks '
+                    f"in Snowflake table '{table_name}'"
+                )
+
+                # Execute the delete and get the affected row count
+                self.cursor.execute(delete_sql)
+                deleted_rows = self.cursor.rowcount
+
+                # Commit the transaction
+                self.connection.commit()
+
+                self.logger.info(f"Blockchain reorg deleted {deleted_rows} rows from table '{table_name}'")
+
+        except Exception as e:
+            self.logger.error(f"Failed to handle blockchain reorg for table '{table_name}': {str(e)}")
+            # Roll back on error
+            if self.connection:
+                self.connection.rollback()
+            raise
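
To make the string-building above concrete: for a single invalidation range such as BlockRange(network='ethereum', start=150, end=200) against a hypothetical table my_table, the assembled statement evaluates (whitespace aside) to:

    DELETE FROM my_table WHERE (
        EXISTS (
            SELECT 1
            FROM TABLE(FLATTEN(input => PARSE_JSON("_META_BLOCK_RANGES"))) f
            WHERE f.value:network::STRING = 'ethereum'
            AND f.value:end::NUMBER >= 150
        )
    )

Each additional network contributes one more EXISTS clause joined with OR, so a single DELETE covers a multi-network reorg. Note that network and reorg_start are interpolated directly rather than bound as parameters, which is fine for trusted BlockRange values but worth remembering if network names could ever come from untrusted input.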

tests/integration/test_snowflake_loader.py
Lines changed: 199 additions & 0 deletions

@@ -387,3 +387,202 @@ def test_schema_with_special_characters(self, snowflake_config, test_table_name,
             assert row['first name'] == 'Alice'
             assert abs(row['total$amount'] - 100.0) < 0.001
             assert row['2024_data'] == 'a'
+
+    def test_handle_reorg_no_metadata_column(self, snowflake_config, test_table_name, cleanup_tables):
+        """Test reorg handling when table lacks metadata column"""
+        from src.amp.streaming.types import BlockRange
+
+        cleanup_tables.append(test_table_name)
+        loader = SnowflakeLoader(snowflake_config)
+
+        with loader:
+            # Create table without metadata column
+            data = pa.table({'id': [1, 2, 3], 'block_num': [100, 150, 200], 'value': [10.0, 20.0, 30.0]})
+            loader.load_table(data, test_table_name, create_table=True)
+
+            # Call handle reorg
+            invalidation_ranges = [BlockRange(network='ethereum', start=150, end=250)]
+
+            # Should log a warning and not modify data
+            loader._handle_reorg(invalidation_ranges, test_table_name)
+
+            # Verify data unchanged
+            loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name}')
+            count = loader.cursor.fetchone()['COUNT(*)']
+            assert count == 3
+
+    def test_handle_reorg_single_network(self, snowflake_config, test_table_name, cleanup_tables):
+        """Test reorg handling for single network data"""
+        import json
+
+        from src.amp.streaming.types import BlockRange
+
+        cleanup_tables.append(test_table_name)
+        loader = SnowflakeLoader(snowflake_config)
+
+        with loader:
+            # Create table with metadata
+            block_ranges = [
+                [{'network': 'ethereum', 'start': 100, 'end': 110}],
+                [{'network': 'ethereum', 'start': 150, 'end': 160}],
+                [{'network': 'ethereum', 'start': 200, 'end': 210}],
+            ]
+
+            data = pa.table(
+                {
+                    'id': [1, 2, 3],
+                    'block_num': [105, 155, 205],
+                    '_meta_block_ranges': [json.dumps(ranges) for ranges in block_ranges],
+                }
+            )
+
+            # Load initial data
+            result = loader.load_table(data, test_table_name, create_table=True)
+            assert result.success
+            assert result.rows_loaded == 3
+
+            # Verify all data exists
+            loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name}')
+            count = loader.cursor.fetchone()['COUNT(*)']
+            assert count == 3
+
+            # Reorg from block 155 - should delete rows 2 and 3
+            invalidation_ranges = [BlockRange(network='ethereum', start=155, end=300)]
+            loader._handle_reorg(invalidation_ranges, test_table_name)
+
+            # Verify only the first row remains
+            loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name}')
+            count = loader.cursor.fetchone()['COUNT(*)']
+            assert count == 1
+
+            loader.cursor.execute(f'SELECT id FROM {test_table_name}')
+            remaining_id = loader.cursor.fetchone()['ID']
+            assert remaining_id == 1
+
+    def test_handle_reorg_multi_network(self, snowflake_config, test_table_name, cleanup_tables):
+        """Test reorg handling preserves data from unaffected networks"""
+        import json
+
+        from src.amp.streaming.types import BlockRange
+
+        cleanup_tables.append(test_table_name)
+        loader = SnowflakeLoader(snowflake_config)
+
+        with loader:
+            # Create data from multiple networks
+            block_ranges = [
+                [{'network': 'ethereum', 'start': 100, 'end': 110}],
+                [{'network': 'polygon', 'start': 100, 'end': 110}],
+                [{'network': 'ethereum', 'start': 150, 'end': 160}],
+                [{'network': 'polygon', 'start': 150, 'end': 160}],
+            ]
+
+            data = pa.table(
+                {
+                    'id': [1, 2, 3, 4],
+                    'network': ['ethereum', 'polygon', 'ethereum', 'polygon'],
+                    '_meta_block_ranges': [json.dumps(r) for r in block_ranges],
+                }
+            )
+
+            # Load initial data
+            result = loader.load_table(data, test_table_name, create_table=True)
+            assert result.success
+            assert result.rows_loaded == 4
+
+            # Reorg only ethereum from block 150
+            invalidation_ranges = [BlockRange(network='ethereum', start=150, end=200)]
+            loader._handle_reorg(invalidation_ranges, test_table_name)
+
+            # Verify ethereum row 3 is deleted but the polygon rows are preserved
+            loader.cursor.execute(f'SELECT id FROM {test_table_name} ORDER BY id')
+            remaining_ids = [row['ID'] for row in loader.cursor.fetchall()]
+            assert remaining_ids == [1, 2, 4]  # Row 3 deleted
+
+    def test_handle_reorg_overlapping_ranges(self, snowflake_config, test_table_name, cleanup_tables):
+        """Test reorg with overlapping block ranges"""
+        import json
+
+        from src.amp.streaming.types import BlockRange
+
+        cleanup_tables.append(test_table_name)
+        loader = SnowflakeLoader(snowflake_config)
+
+        with loader:
+            # Create data with ranges around the reorg point
+            block_ranges = [
+                [{'network': 'ethereum', 'start': 90, 'end': 110}],  # Ends before reorg
+                [{'network': 'ethereum', 'start': 140, 'end': 160}],  # Overlaps with reorg
+                [{'network': 'ethereum', 'start': 170, 'end': 190}],  # After reorg
+            ]
+
+            data = pa.table({'id': [1, 2, 3], '_meta_block_ranges': [json.dumps(ranges) for ranges in block_ranges]})
+
+            # Load initial data
+            result = loader.load_table(data, test_table_name, create_table=True)
+            assert result.success
+            assert result.rows_loaded == 3
+
+            # Reorg from block 150 - should delete rows where end >= 150
+            invalidation_ranges = [BlockRange(network='ethereum', start=150, end=200)]
+            loader._handle_reorg(invalidation_ranges, test_table_name)
+
+            # Only the first row should remain (ends at 110 < 150)
+            loader.cursor.execute(f'SELECT COUNT(*) FROM {test_table_name}')
+            count = loader.cursor.fetchone()['COUNT(*)']
+            assert count == 1
+
+            loader.cursor.execute(f'SELECT id FROM {test_table_name}')
+            remaining_id = loader.cursor.fetchone()['ID']
+            assert remaining_id == 1
+
+    def test_streaming_with_reorg(self, snowflake_config, test_table_name, cleanup_tables):
+        """Test streaming data with reorg support"""
+        from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch, ResponseBatchWithReorg
+
+        cleanup_tables.append(test_table_name)
+        loader = SnowflakeLoader(snowflake_config)
+
+        with loader:
+            # Create streaming data with metadata
+            data1 = pa.RecordBatch.from_pydict({'id': [1, 2], 'value': [100, 200]})
+
+            data2 = pa.RecordBatch.from_pydict({'id': [3, 4], 'value': [300, 400]})
+
+            # Create response batches
+            response1 = ResponseBatchWithReorg(
+                is_reorg=False,
+                data=ResponseBatch(
+                    data=data1, metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=100, end=110)])
+                ),
+            )
+
+            response2 = ResponseBatchWithReorg(
+                is_reorg=False,
+                data=ResponseBatch(
+                    data=data2, metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=150, end=160)])
+                ),
+            )
+
+            # Simulate a reorg event
+            reorg_response = ResponseBatchWithReorg(
+                is_reorg=True, invalidation_ranges=[BlockRange(network='ethereum', start=150, end=200)]
+            )
+
+            # Process the streaming data
+            stream = [response1, response2, reorg_response]
+            results = list(loader.load_stream_continuous(iter(stream), test_table_name))
+
+            # Verify results
+            assert len(results) == 3
+            assert results[0].success
+            assert results[0].rows_loaded == 2
+            assert results[1].success
+            assert results[1].rows_loaded == 2
+            assert results[2].success
+            assert results[2].is_reorg
+
+            # Verify the reorg deleted the second batch
+            loader.cursor.execute(f'SELECT id FROM {test_table_name} ORDER BY id')
+            remaining_ids = [row['ID'] for row in loader.cursor.fetchall()]
+            assert remaining_ids == [1, 2]  # 3 and 4 deleted by reorg
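
The rule these tests exercise can be stated per row: a row is deleted when its _meta_block_ranges JSON contains any entry whose network matches an invalidation range and whose end is at or past that range's start block. A minimal Python mirror of the SQL predicate (a hypothetical helper for reasoning about expected outcomes, not part of the commit):

    import json

    def would_delete(meta_block_ranges: str, network: str, reorg_start: int) -> bool:
        # Mirrors the SQL: f.value:network::STRING = '<network>' AND f.value:end::NUMBER >= <reorg_start>
        return any(
            r['network'] == network and r['end'] >= reorg_start
            for r in json.loads(meta_block_ranges)
        )

    # Outcomes match test_handle_reorg_overlapping_ranges with a reorg at block 150:
    assert not would_delete('[{"network": "ethereum", "start": 90, "end": 110}]', 'ethereum', 150)   # kept
    assert would_delete('[{"network": "ethereum", "start": 140, "end": 160}]', 'ethereum', 150)      # deleted
    assert would_delete('[{"network": "ethereum", "start": 170, "end": 190}]', 'ethereum', 150)      # deleted
    assert not would_delete('[{"network": "polygon", "start": 150, "end": 160}]', 'ethereum', 150)   # other network kept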
