
Commit b4b1007

WIP: snowpipe fixes/tests and label manager
1 parent 75f7a7b commit b4b1007

File tree: 5 files changed, +280 −106 lines


src/amp/client.py

Lines changed: 17 additions & 3 deletions
@@ -7,6 +7,7 @@
 
 from . import FlightSql_pb2
 from .config.connection_manager import ConnectionManager
+from .config.label_manager import LabelManager
 from .loaders.registry import create_loader, get_available_loaders
 from .loaders.types import LoadConfig, LoadMode, LoadResult
 from .streaming import (
@@ -105,6 +106,7 @@ class Client:
     def __init__(self, url):
         self.conn = flight.connect(url)
         self.connection_manager = ConnectionManager()
+        self.label_manager = LabelManager()
         self.logger = logging.getLogger(__name__)
 
     def sql(self, query: str) -> QueryBuilder:
@@ -123,6 +125,18 @@ def configure_connection(self, name: str, loader: str, config: Dict[str, Any]) -
         """Configure a named connection for reuse"""
         self.connection_manager.add_connection(name, loader, config)
 
+    def configure_label(self, name: str, csv_path: str, binary_columns: Optional[List[str]] = None) -> None:
+        """
+        Configure a label dataset from a CSV file for joining with streaming data.
+
+        Args:
+            name: Unique name for this label dataset
+            csv_path: Path to the CSV file
+            binary_columns: List of column names containing hex addresses to convert to binary.
+                If None, auto-detects columns with 'address' in the name.
+        """
+        self.label_manager.add_label(name, csv_path, binary_columns)
+
     def list_connections(self) -> Dict[str, str]:
         """List all configured connections"""
         return self.connection_manager.list_connections()
@@ -238,7 +252,7 @@ def _load_table(
     ) -> LoadResult:
         """Load a complete Arrow Table"""
         try:
-            loader_instance = create_loader(loader, config)
+            loader_instance = create_loader(loader, config, label_manager=self.label_manager)
 
             with loader_instance:
                 return loader_instance.load_table(table, table_name, **load_config.__dict__, **kwargs)
@@ -265,7 +279,7 @@ def _load_stream(
     ) -> Iterator[LoadResult]:
         """Load from a stream of batches"""
        try:
-            loader_instance = create_loader(loader, config)
+            loader_instance = create_loader(loader, config, label_manager=self.label_manager)
 
             with loader_instance:
                 yield from loader_instance.load_stream(batch_stream, table_name, **load_config.__dict__, **kwargs)
@@ -355,7 +369,7 @@ def query_and_load_streaming(
         self.logger.info(f'Starting streaming query to {loader_type}:{destination}')
 
         # Create loader instance early to access checkpoint store
-        loader_instance = create_loader(loader_type, loader_config)
+        loader_instance = create_loader(loader_type, loader_config, label_manager=self.label_manager)
 
         # Load checkpoint and create resume watermark if enabled (default: enabled)
         if resume_watermark is None and kwargs.get('resume', True):
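For context, a minimal usage sketch of the new label API. The connection URL, CSV path, and column names are illustrative, and the load-call shape is assumed rather than taken from this diff; only configure_label and the three join kwargs are confirmed by the changes above.

from amp.client import Client

client = Client('grpc://localhost:8815')  # assumed Flight endpoint, not from this diff

# Register a CSV of labels; per the configure_label docstring, columns with
# 'address' in the name are auto-converted to binary when binary_columns is omitted.
client.configure_label('token_labels', 'labels/tokens.csv')

# Loaders created by this client now receive label_manager, so a load call can
# request the join via the three kwargs checked in DataLoader._try_load_batch:
#   label='token_labels', label_key_column='address', stream_key_column='token_address'
# (the exact load-call signature is not shown in this diff).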

src/amp/loaders/base.py

Lines changed: 150 additions & 2 deletions
@@ -50,11 +50,12 @@ class DataLoader(ABC, Generic[TConfig]):
     REQUIRES_SCHEMA_MATCH: bool = True
     SUPPORTS_TRANSACTIONS: bool = False
 
-    def __init__(self, config: Dict[str, Any]) -> None:
+    def __init__(self, config: Dict[str, Any], label_manager=None) -> None:
         self.logger: Logger = logging.getLogger(f'{self.__class__.__name__}')
         self._connection: Optional[Any] = None
         self._is_connected: bool = False
         self._created_tables: Set[str] = set()  # Track created tables
+        self.label_manager = label_manager  # For CSV label joining
 
         # Parse configuration into typed format
         self.config: TConfig = self._parse_config(config)
@@ -240,6 +241,7 @@ def _try_load_batch(self, batch: pa.RecordBatch, table_name: str, **kwargs) -> L
         This is called by load_batch() within the retry loop. It handles:
         - Connection management
         - Mode validation
+        - Label joining (if configured)
         - Table creation
         - Error handling and timing
         - Metadata generation
@@ -258,7 +260,26 @@ def _try_load_batch(self, batch: pa.RecordBatch, table_name: str, **kwargs) -> L
         if mode not in self.SUPPORTED_MODES:
             raise ValueError(f'Unsupported mode {mode}. Supported modes: {self.SUPPORTED_MODES}')
 
-        # Handle table creation
+        # Apply label joining if requested
+        label_name = kwargs.get('label')
+        label_key_column = kwargs.get('label_key_column')
+        stream_key_column = kwargs.get('stream_key_column')
+
+        if label_name or label_key_column or stream_key_column:
+            # If any label param is provided, all must be provided
+            if not (label_name and label_key_column and stream_key_column):
+                raise ValueError(
+                    'Label joining requires all three parameters: label, label_key_column, stream_key_column'
+                )
+
+            # Perform the join
+            batch = self._join_with_labels(batch, label_name, label_key_column, stream_key_column)
+            self.logger.debug(
+                f'Joined batch with label {label_name}: {batch.num_rows} rows after join '
+                f'(columns: {", ".join(batch.schema.names)})'
+            )
+
+        # Handle table creation (use joined schema if applicable)
         if kwargs.get('create_table', True) and table_name not in self._created_tables:
             if hasattr(self, '_create_table_from_schema'):
                 self._create_table_from_schema(batch.schema, table_name)
@@ -891,6 +912,133 @@ def _get_loader_table_metadata(
         """Override in subclasses to add loader-specific table metadata"""
         return {}
 
+    def _get_effective_schema(
+        self, original_schema: pa.Schema, label_name: Optional[str], label_key_column: Optional[str]
+    ) -> pa.Schema:
+        """
+        Get effective schema by merging label columns into original schema.
+
+        If label_name is None, returns original schema unchanged.
+        Otherwise, merges label columns (excluding the join key which is already in original).
+
+        Args:
+            original_schema: Original data schema
+            label_name: Name of the label dataset (None if no labels)
+            label_key_column: Column name in the label table to join on
+
+        Returns:
+            Schema with label columns merged in
+        """
+        if label_name is None or label_key_column is None:
+            return original_schema
+
+        if self.label_manager is None:
+            raise ValueError('Label manager not configured')
+
+        label_table = self.label_manager.get_label(label_name)
+        if label_table is None:
+            raise ValueError(f"Label '{label_name}' not found")
+
+        # Start with original schema fields
+        merged_fields = list(original_schema)
+
+        # Add label columns (excluding the join key which is already in original)
+        for field in label_table.schema:
+            if field.name != label_key_column and field.name not in original_schema.names:
+                merged_fields.append(field)
+
+        return pa.schema(merged_fields)
+
+    def _join_with_labels(
+        self, batch: pa.RecordBatch, label_name: str, label_key_column: str, stream_key_column: str
+    ) -> pa.RecordBatch:
+        """
+        Join batch data with labels using inner join.
+
+        Handles automatic type conversion between stream and label key columns
+        (e.g., string ↔ binary for Ethereum addresses).
+
+        Args:
+            batch: Original data batch
+            label_name: Name of the label dataset
+            label_key_column: Column name in the label table to join on
+            stream_key_column: Column name in the batch data to join on
+
+        Returns:
+            Joined RecordBatch with label columns added
+
+        Raises:
+            ValueError: If label_manager not configured, label not found, or invalid columns
+        """
+        if self.label_manager is None:
+            raise ValueError('Label manager not configured')
+
+        label_table = self.label_manager.get_label(label_name)
+        if label_table is None:
+            raise ValueError(f"Label '{label_name}' not found")
+
+        # Validate columns exist
+        if stream_key_column not in batch.schema.names:
+            raise ValueError(f"Stream key column '{stream_key_column}' not found in batch schema")
+
+        if label_key_column not in label_table.schema.names:
+            raise ValueError(f"Label key column '{label_key_column}' not found in label table")
+
+        # Convert batch to table for join operation
+        batch_table = pa.Table.from_batches([batch])
+
+        # Get column types for join keys
+        stream_key_type = batch_table.schema.field(stream_key_column).type
+        label_key_type = label_table.schema.field(label_key_column).type
+
+        # If types don't match, cast one to match the other
+        # Prefer casting to binary since that's more efficient
+        import pyarrow.compute as pc
+
+        if stream_key_type != label_key_type:
+            # Try to cast stream key to label key type
+            if pa.types.is_fixed_size_binary(label_key_type) and pa.types.is_string(stream_key_type):
+                # Cast string to binary (hex strings like "0xABCD...")
+                def hex_to_binary(value):
+                    if value is None:
+                        return None
+                    # Remove 0x prefix if present
+                    hex_str = value[2:] if value.startswith('0x') else value
+                    return bytes.fromhex(hex_str)
+
+                # Cast the stream column to binary
+                stream_column = batch_table.column(stream_key_column)
+                binary_length = label_key_type.byte_width
+                binary_values = pa.array(
+                    [hex_to_binary(v.as_py()) for v in stream_column], type=pa.binary(binary_length)
+                )
+                batch_table = batch_table.set_column(
+                    batch_table.schema.get_field_index(stream_key_column), stream_key_column, binary_values
+                )
+            elif pa.types.is_binary(stream_key_type) and pa.types.is_string(label_key_type):
+                # Cast binary to string (for test compatibility)
+                stream_column = batch_table.column(stream_key_column)
+                string_values = pa.array([v.as_py().hex() if v.as_py() else None for v in stream_column])
+                batch_table = batch_table.set_column(
+                    batch_table.schema.get_field_index(stream_key_column), stream_key_column, string_values
+                )
+
+        # Perform inner join using PyArrow compute
+        # Inner join will filter out rows where stream key doesn't match any label key
+        joined_table = batch_table.join(
+            label_table, keys=stream_key_column, right_keys=label_key_column, join_type='inner'
+        )
+
+        # Convert back to RecordBatch
+        if joined_table.num_rows == 0:
+            # Empty result - return empty batch with joined schema
+            # Need to create empty arrays for each column
+            empty_data = {field.name: pa.array([], type=field.type) for field in joined_table.schema}
+            return pa.RecordBatch.from_pydict(empty_data, schema=joined_table.schema)
+
+        # Return as a single batch (assuming batch sizes are manageable)
+        return joined_table.to_batches()[0]
+
     def __enter__(self) -> 'DataLoader':
         self.connect()
         return self
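A standalone PyArrow sketch of the operation _join_with_labels performs: cast the stream key to the label key type, then inner-join so unlabeled rows are dropped and label columns are appended. Data and column names are invented; variable-length binary keys are used here for brevity, whereas the loader also handles fixed-size binary label keys and empty join results.

import pyarrow as pa

# Stream batch keyed by a hex-string address
batch = pa.RecordBatch.from_pydict({
    'address': ['0x' + '00' * 19 + 'aa'],
    'value': [42],
})

# Label table keyed by raw binary addresses
labels = pa.table({
    'address': pa.array([b'\x00' * 19 + b'\xaa'], type=pa.binary()),
    'label': ['treasury'],
})

# Cast the stream key to the label key type, as the loader does
def hex_to_binary(value):
    hex_str = value[2:] if value.startswith('0x') else value
    return bytes.fromhex(hex_str)

stream_table = pa.Table.from_batches([batch])
binary_key = pa.array(
    [hex_to_binary(v.as_py()) for v in stream_table.column('address')], type=pa.binary()
)
stream_table = stream_table.set_column(0, 'address', binary_key)

# Inner join: rows without a matching label are dropped, label columns are appended
joined = stream_table.join(labels, keys='address', join_type='inner')
print(joined.to_pydict())  # {'address': [b'\x00...\xaa'], 'value': [42], 'label': ['treasury']}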

src/amp/loaders/implementations/snowflake_loader.py

Lines changed: 36 additions & 26 deletions
@@ -351,8 +351,8 @@ class SnowflakeLoader(DataLoader[SnowflakeConnectionConfig]):
     REQUIRES_SCHEMA_MATCH = False
     SUPPORTS_TRANSACTIONS = True
 
-    def __init__(self, config: Dict[str, Any]) -> None:
-        super().__init__(config)
+    def __init__(self, config: Dict[str, Any], label_manager=None) -> None:
+        super().__init__(config, label_manager=label_manager)
         self.connection: Optional[SnowflakeConnection] = None
         self.cursor = None
         self._created_tables = set()  # Track created tables
@@ -625,9 +625,9 @@ def disconnect(self) -> None:
         for channel_key, channel in self.streaming_channels.items():
             try:
                 channel.close()
-                self.logger.debug(f'Closed channel: {channel.name}')
+                self.logger.debug(f'Closed channel: {channel_key}')
             except Exception as e:
-                self.logger.warning(f'Error closing channel: {e}')
+                self.logger.warning(f'Error closing channel {channel_key}: {e}')
 
         self.streaming_channels.clear()
 
@@ -736,13 +736,19 @@ def _load_via_stage(self, batch: pa.RecordBatch, table_name: str) -> int:
 
         # Identify binary columns and convert to hex for CSV compatibility
         binary_columns = {}
+        # Track VARIANT columns so we can use PARSE_JSON in COPY INTO
+        variant_columns = set()
         modified_arrays = []
         modified_fields = []
 
         t_conversion_start = time.time()
         for i, field in enumerate(batch.schema):
             col_array = batch.column(i)
 
+            # Track _meta_block_ranges as VARIANT column for JSON parsing
+            if field.name == '_meta_block_ranges':
+                variant_columns.add(field.name)
+
             # Check if this is a binary type that needs hex encoding
             if pa.types.is_binary(field.type) or pa.types.is_large_binary(field.type) or pa.types.is_fixed_size_binary(field.type):
                 binary_columns[field.name] = field.type
@@ -801,12 +807,15 @@ def _load_via_stage(self, batch: pa.RecordBatch, table_name: str) -> int:
         t_put_end = time.time()
         self.logger.debug(f'PUT command took {t_put_end - t_put_start:.2f}s')
 
-        # Build column list with transformations - convert hex strings back to binary
+        # Build column list with transformations - convert hex strings back to binary, parse JSON for VARIANT
         final_column_specs = []
         for i, field in enumerate(batch.schema, start=1):
             if field.name in binary_columns:
                 # Use TO_BINARY to convert hex string back to binary
                 final_column_specs.append(f'TO_BINARY(${i}, \'HEX\')')
+            elif field.name in variant_columns:
+                # Use PARSE_JSON to convert JSON string to VARIANT
+                final_column_specs.append(f'PARSE_JSON(${i})')
             else:
                 final_column_specs.append(f'${i}')
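To make the COPY INTO transformation concrete, here is a small illustration of how the column specs come out for a hypothetical three-column batch (column names are invented for the example; the logic mirrors the loop above).

# Hypothetical batch columns: tx_hash (binary), value (int), _meta_block_ranges (JSON string)
binary_columns = {'tx_hash'}
variant_columns = {'_meta_block_ranges'}
column_names = ['tx_hash', 'value', '_meta_block_ranges']

final_column_specs = []
for i, name in enumerate(column_names, start=1):
    if name in binary_columns:
        final_column_specs.append(f"TO_BINARY(${i}, 'HEX')")  # hex CSV text -> BINARY
    elif name in variant_columns:
        final_column_specs.append(f'PARSE_JSON(${i})')        # JSON text -> VARIANT
    else:
        final_column_specs.append(f'${i}')

print(final_column_specs)
# ["TO_BINARY($1, 'HEX')", '$2', 'PARSE_JSON($3)']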

@@ -1468,9 +1477,9 @@ def _handle_reorg(self, invalidation_ranges: List[BlockRange], table_name: str)
             try:
                 channel.close()
                 del self.streaming_channels[channel_key]
-                self.logger.debug(f'Closed streaming channel: {channel.name}')
+                self.logger.debug(f'Closed streaming channel: {channel_key}')
             except Exception as e:
-                self.logger.warning(f'Error closing channel {channel.name}: {e}')
+                self.logger.warning(f'Error closing channel {channel_key}: {e}')
                 # Continue closing other channels even if one fails
 
         self.logger.info(
@@ -1482,7 +1491,7 @@ def _handle_reorg(self, invalidation_ranges: List[BlockRange], table_name: str)
             """
            SELECT COUNT(*) as count
            FROM INFORMATION_SCHEMA.COLUMNS
-           WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? AND COLUMN_NAME = '_META_BLOCK_RANGES'
+           WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? AND COLUMN_NAME = '_meta_block_ranges'
            """,
            (self.config.schema, table_name.upper()),
        )
@@ -1494,32 +1503,33 @@ def _handle_reorg(self, invalidation_ranges: List[BlockRange], table_name: str)
             )
             return
 
-        # Build DELETE statement with conditions for each invalidation range
-        # Snowflake's PARSE_JSON and ARRAY_SIZE functions help work with JSON data
-        delete_conditions = []
+        # Build WHERE conditions for FLATTEN-based deletion
+        # Since Snowflake doesn't support complex subqueries in DELETE WHERE,
+        # we use a CTE-based approach with row identification
+        where_conditions = []
 
         for range_obj in invalidation_ranges:
             network = range_obj.network
             reorg_start = range_obj.start
 
             # Create condition for this network's reorg
-            # Delete rows where any range in the JSON array for this network has end >= reorg_start
-            condition = f"""
-                EXISTS (
-                    SELECT 1
-                    FROM TABLE(FLATTEN(input => PARSE_JSON("_META_BLOCK_RANGES"))) f
-                    WHERE f.value:network::STRING = '{network}'
-                    AND f.value:end::NUMBER >= {reorg_start}
+            where_conditions.append(f"""
+                (f.value:network::STRING = '{network}' AND f.value:end::NUMBER >= {reorg_start})
+            """)
+
+        if where_conditions:
+            # Use a CTE to identify rows to delete, then delete using METADATA$ROW_ID
+            where_clause = ' OR '.join(where_conditions)
+
+            # Create DELETE SQL using CTE for row identification
+            delete_sql = f"""
+            DELETE FROM {table_name}
+            WHERE "_meta_block_ranges" IN (
+                SELECT DISTINCT "_meta_block_ranges"
+                FROM {table_name}, LATERAL FLATTEN(input => "_meta_block_ranges") f
+                WHERE {where_clause}
                 )
                 """
-            delete_conditions.append(condition)
-
-        # Combine conditions with OR
-        if delete_conditions:
-            where_clause = ' OR '.join(f'({cond})' for cond in delete_conditions)
-
-            # Execute deletion
-            delete_sql = f'DELETE FROM {table_name} WHERE {where_clause}'
 
         self.logger.info(
             f'Executing blockchain reorg deletion for {len(invalidation_ranges)} networks '
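For reference, a self-contained rendering of the new deletion statement for one assumed invalidation range. The table and network names are illustrative and render_reorg_delete is a hypothetical helper that just mirrors the f-strings above, not part of the loader.

def render_reorg_delete(table_name, ranges):
    # ranges: list of (network, reorg_start) pairs
    where_conditions = [
        f"(f.value:network::STRING = '{network}' AND f.value:end::NUMBER >= {reorg_start})"
        for network, reorg_start in ranges
    ]
    where_clause = ' OR '.join(where_conditions)
    return f"""
    DELETE FROM {table_name}
    WHERE "_meta_block_ranges" IN (
        SELECT DISTINCT "_meta_block_ranges"
        FROM {table_name}, LATERAL FLATTEN(input => "_meta_block_ranges") f
        WHERE {where_clause}
    )
    """

# Any row whose _meta_block_ranges JSON holds a range for 'ethereum' ending at or
# after block 19000000 is removed.
print(render_reorg_delete('transfers', [('ethereum', 19000000)]))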
