
Commit 18a83b8

label manager and parallel load test
1 parent a121d3e commit 18a83b8

File tree

2 files changed (+335, -0)

apps/test_erc20_parallel_load.py

Lines changed: 164 additions & 0 deletions
@@ -0,0 +1,164 @@
#!/usr/bin/env python3
"""
Real-world test: Load ERC20 transfers into Snowflake using parallel streaming.

Usage:
    python apps/test_erc20_parallel_load.py [--blocks BLOCKS] [--workers WORKERS]

Example:
    python apps/test_erc20_parallel_load.py --blocks 100000 --workers 8
"""

import argparse
import os
import time
from datetime import datetime

from amp.client import Client
from amp.streaming.parallel import ParallelConfig


def get_recent_block_range(client: Client, num_blocks: int = 100_000):
    """Query the amp server for the recent block range."""
    print(f'\n🔍 Detecting recent block range ({num_blocks:,} blocks)...')

    query = 'SELECT MAX(block_num) as max_block FROM eth_firehose.logs'
    result = client.get_sql(query, read_all=True)

    if result.num_rows == 0:
        raise RuntimeError('No data found in eth_firehose.logs')

    max_block = result.column('max_block')[0].as_py()
    if max_block is None:
        raise RuntimeError('No blocks found in eth_firehose.logs')

    min_block = max(0, max_block - num_blocks)

    print(f'✅ Block range: {min_block:,} to {max_block:,} ({max_block - min_block:,} blocks)')
    return min_block, max_block


def load_erc20_transfers(num_blocks: int = 100_000, num_workers: int = 8):
    """Load ERC20 transfers using parallel streaming."""

    # Initialize client
    server_url = os.getenv('AMP_SERVER_URL', 'grpc://34.27.238.174:80')
    client = Client(server_url)
    print(f'📡 Connected to amp server: {server_url}')

    # Get recent block range
    min_block, max_block = get_recent_block_range(client, num_blocks)

    # Generate unique table name
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    table_name = f'erc20_transfers_{timestamp}'
    print(f'\n📊 Target table: {table_name}')

    # ERC20 Transfer event signature
    transfer_sig = 'Transfer(address indexed from, address indexed to, uint256 value)'

    # ERC20 transfer query. The topic3 IS NULL filter excludes ERC721 Transfer
    # events, which share the same topic0 but index the token id as a third topic.
    erc20_query = f"""
    select
        pc.block_num,
        pc.block_hash,
        pc.timestamp,
        pc.tx_hash,
        pc.tx_index,
        pc.log_index,
        pc.dec['from'] as from_address,
        pc.dec['to'] as to_address,
        pc.dec['value'] as value
    from (
        select
            l.block_num,
            l.block_hash,
            l.tx_hash,
            l.tx_index,
            l.log_index,
            l.timestamp,
            evm_decode(l.topic1, l.topic2, l.topic3, l.data, '{transfer_sig}') as dec
        from eth_firehose.logs l
        where
            l.topic0 = evm_topic('{transfer_sig}') and
            l.topic3 IS NULL) pc
    """

    # Configure Snowflake connection
    snowflake_config = {
        'account': os.getenv('SNOWFLAKE_ACCOUNT'),
        'user': os.getenv('SNOWFLAKE_USER'),
        'warehouse': os.getenv('SNOWFLAKE_WAREHOUSE'),
        'database': os.getenv('SNOWFLAKE_DATABASE'),
        'private_key': os.getenv('SNOWFLAKE_PRIVATE_KEY'),
        'loading_method': 'stage',  # Use fast bulk loading via COPY INTO
    }

    client.configure_connection(name='snowflake_erc20', loader='snowflake', config=snowflake_config)

    # Configure parallel execution
    parallel_config = ParallelConfig(
        num_workers=num_workers,
        table_name='eth_firehose.logs',
        min_block=min_block,
        max_block=max_block,
        block_column='block_num',
    )

    print(f'\n🚀 Starting parallel load with {num_workers} workers...\n')

    start_time = time.time()

    # Load data in parallel (stops after processing the block range)
    results = list(
        client.sql(erc20_query).load(
            connection='snowflake_erc20', destination=table_name, stream=True, parallel_config=parallel_config
        )
    )

    duration = time.time() - start_time

    # Calculate statistics
    total_rows = sum(r.rows_loaded for r in results if r.success)
    rows_per_sec = total_rows / duration if duration > 0 else 0
    partitions = [r for r in results if 'partition_id' in r.metadata]
    successful_workers = len(partitions)
    failed_workers = num_workers - successful_workers

    # Print results
    print(f'\n{"=" * 70}')
    print('🎉 ERC20 Parallel Load Complete!')
    print(f'{"=" * 70}')
    print(f'📊 Table name: {table_name}')
    print(f'📦 Block range: {min_block:,} to {max_block:,}')
    print(f'📈 Rows loaded: {total_rows:,}')
    print(f'⏱️ Duration: {duration:.2f}s')
    print(f'🚀 Throughput: {rows_per_sec:,.0f} rows/sec')
    print(f'👷 Workers: {successful_workers}/{num_workers} succeeded')
    if failed_workers > 0:
        print(f'⚠️ Failed workers: {failed_workers}')
    print(f'📊 Avg rows/block: {total_rows / (max_block - min_block):.0f}')
    print(f'{"=" * 70}')

    print(f'\n✅ Table "{table_name}" is ready in Snowflake for testing!')
    print(f'   Query it with: SELECT * FROM {table_name} LIMIT 10;')

    return table_name, total_rows, duration


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Load ERC20 transfers into Snowflake using parallel streaming')
    parser.add_argument(
        '--blocks', type=int, default=100_000, help='Number of recent blocks to load (default: 100,000)'
    )
    parser.add_argument('--workers', type=int, default=8, help='Number of parallel workers (default: 8)')

    args = parser.parse_args()

    try:
        load_erc20_transfers(num_blocks=args.blocks, num_workers=args.workers)
    except KeyboardInterrupt:
        print('\n\n⚠️ Interrupted by user')
    except Exception as e:
        print(f'\n\n❌ Error: {e}')
        raise
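
For intuition about what the ParallelConfig above asks the loader to do, here is a minimal, self-contained sketch of splitting a block range evenly across workers. The function name and the even-split strategy are illustrative assumptions, not amp's actual partitioning logic.

def split_block_range(min_block: int, max_block: int, num_workers: int):
    """Evenly partition [min_block, max_block] into num_workers contiguous ranges."""
    total = max_block - min_block + 1
    chunk = total // num_workers
    ranges = []
    for i in range(num_workers):
        start = min_block + i * chunk
        # The last worker absorbs the remainder so the full range is covered.
        end = max_block if i == num_workers - 1 else start + chunk - 1
        ranges.append((start, end))
    return ranges

# Example: 8 workers over a 100,000-block window
print(split_block_range(20_000_000, 20_100_000, 8))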

src/amp/config/label_manager.py

Lines changed: 171 additions & 0 deletions
@@ -0,0 +1,171 @@
"""
Label Manager for CSV-based label datasets.

This module provides functionality to register and manage CSV label datasets
that can be joined with streaming data during loading operations.
"""

import logging
from typing import Dict, List, Optional

import pyarrow as pa
import pyarrow.csv as csv


class LabelManager:
    """
    Manages CSV label datasets for joining with streaming data.

    Labels are registered by name and loaded as PyArrow Tables for efficient
    joining operations. This allows reuse of label datasets across multiple
    queries and loaders.

    Example:
        >>> manager = LabelManager()
        >>> manager.add_label('token_labels', '/path/to/tokens.csv')
        >>> label_table = manager.get_label('token_labels')
    """

    def __init__(self):
        self._labels: Dict[str, pa.Table] = {}
        self.logger = logging.getLogger(__name__)

    def add_label(self, name: str, csv_path: str, binary_columns: Optional[List[str]] = None) -> None:
        """
        Load and register a CSV label dataset with automatic hex-to-binary conversion.

        Hex string columns (such as Ethereum addresses) are automatically converted to
        binary format for efficient storage and joining. This reduces memory usage
        by roughly 50% and improves join performance.

        Args:
            name: Unique name for this label dataset. If the name is already
                registered, the existing dataset is replaced.
            csv_path: Path to the CSV file
            binary_columns: List of column names containing hex addresses to convert to binary.
                If None, auto-detects columns with 'address' in the name.

        Raises:
            FileNotFoundError: If the CSV file doesn't exist
            ValueError: If the CSV cannot be parsed
        """
        if name in self._labels:
            self.logger.warning(f"Label '{name}' already exists, replacing with new data")

        try:
            # Load CSV as a PyArrow Table (initially as strings)
            temp_table = csv.read_csv(csv_path, read_options=csv.ReadOptions(autogenerate_column_names=False))

            # Force all columns to be strings initially
            column_types = {col_name: pa.string() for col_name in temp_table.column_names}
            convert_opts = csv.ConvertOptions(column_types=column_types)
            label_table = csv.read_csv(csv_path, convert_options=convert_opts)

            # Auto-detect or use specified binary columns
            if binary_columns is None:
                # Auto-detect columns with 'address' in the name (case-insensitive)
                binary_columns = [col for col in label_table.column_names if 'address' in col.lower()]

            # Convert hex string columns to binary for efficiency
            converted_columns = []
            for col_name in binary_columns:
                if col_name not in label_table.column_names:
                    self.logger.warning(f"Binary column '{col_name}' not found in CSV, skipping")
                    continue

                hex_col = label_table.column(col_name)

                # Detect hex string format and convert to binary.
                # Sample the first non-null value to determine the format.
                sample_value = None
                for v in hex_col.to_pylist()[:100]:  # Check the first 100 values
                    if v is not None:
                        sample_value = v
                        break

                if sample_value is None:
                    self.logger.warning(f"Column '{col_name}' has no non-null values, skipping conversion")
                    continue

                # Detect whether it's a hex string (with or without the 0x prefix)
                if isinstance(sample_value, str) and all(c in '0123456789abcdefABCDEFx' for c in sample_value):
                    # Determine the binary length from the hex string
                    hex_str = sample_value[2:] if sample_value.startswith('0x') else sample_value
                    binary_length = len(hex_str) // 2

                    # Convert all values to binary (fixed-size to match streaming data)
                    def hex_to_binary(v):
                        if v is None:
                            return None
                        hex_str = v[2:] if v.startswith('0x') else v
                        return bytes.fromhex(hex_str)

                    binary_values = pa.array(
                        [hex_to_binary(v) for v in hex_col.to_pylist()],
                        # Fixed-size binary to match server data (e.g., 20 bytes for addresses)
                        type=pa.binary(binary_length),
                    )

                    # Replace the column
                    label_table = label_table.set_column(
                        label_table.schema.get_field_index(col_name), col_name, binary_values
                    )
                    converted_columns.append(f'{col_name} (hex→fixed_size_binary[{binary_length}])')
                    self.logger.info(f"Converted '{col_name}' from hex string to fixed_size_binary[{binary_length}]")

            self._labels[name] = label_table

            conversion_info = f', converted: {", ".join(converted_columns)}' if converted_columns else ''
            self.logger.info(
                f"Loaded label '{name}' from {csv_path}: "
                f'{label_table.num_rows:,} rows, {len(label_table.schema)} columns '
                f'({", ".join(label_table.schema.names)}){conversion_info}'
            )

        except FileNotFoundError:
            raise FileNotFoundError(f'Label CSV file not found: {csv_path}')
        except Exception as e:
            raise ValueError(f"Failed to load label CSV '{csv_path}': {e}") from e

    def get_label(self, name: str) -> Optional[pa.Table]:
        """
        Get a label table by name.

        Args:
            name: Name of the label dataset

        Returns:
            PyArrow Table containing label data, or None if not found
        """
        return self._labels.get(name)

    def list_labels(self) -> List[str]:
        """
        List all registered label names.

        Returns:
            List of label names
        """
        return list(self._labels.keys())

    def remove_label(self, name: str) -> bool:
        """
        Remove a label dataset.

        Args:
            name: Name of the label to remove

        Returns:
            True if the label was removed, False if it didn't exist
        """
        if name in self._labels:
            del self._labels[name]
            self.logger.info(f"Removed label '{name}'")
            return True
        return False

    def clear(self) -> None:
        """Remove all label datasets."""
        count = len(self._labels)
        self._labels.clear()
        self.logger.info(f'Cleared {count} label dataset(s)')
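
A short usage sketch for LabelManager follows; the module path, CSV paths, and column names are hypothetical and only illustrate the intended call pattern.

from amp.config.label_manager import LabelManager

manager = LabelManager()

# Auto-detection: any column whose name contains 'address' is converted
# from a hex string to fixed-size binary (e.g. 20 bytes for Ethereum addresses).
manager.add_label('token_labels', '/path/to/tokens.csv')

# Or name the hex columns explicitly (hypothetical column name).
manager.add_label('exchange_labels', '/path/to/exchanges.csv', binary_columns=['wallet_address'])

print(manager.list_labels())
table = manager.get_label('token_labels')
if table is not None:
    print(table.schema)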
