Commit 801548d

Formatting
1 parent 3f74af0 commit 801548d

22 files changed (+581, -637 lines)

apps/snowflake_parallel_loader.py

Lines changed: 35 additions & 79 deletions
@@ -61,9 +61,7 @@ def configure_logging(verbose: bool = False):
     """
     # Configure root logger first
     logging.basicConfig(
-        level=logging.INFO,
-        format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
-        datefmt='%Y-%m-%d %H:%M:%S'
+        level=logging.INFO, format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', datefmt='%Y-%m-%d %H:%M:%S'
     )

     if not verbose:
@@ -223,8 +221,16 @@ def print_configuration(args, min_block: int, max_block: int, has_labels: bool):
         print(f'🏷️ Label Joining: ENABLED ({args.label_name})')


-def print_results(results, table_name: str, min_block: int, max_block: int,
-                  duration: float, num_workers: int, has_labels: bool, label_columns: str = ''):
+def print_results(
+    results,
+    table_name: str,
+    min_block: int,
+    max_block: int,
+    duration: float,
+    num_workers: int,
+    has_labels: bool,
+    label_columns: str = '',
+):
     """Print execution results and sample queries."""
     # Calculate statistics
     total_rows = sum(r.rows_loaded for r in results if r.success)
@@ -268,131 +274,81 @@ def main():
     parser = argparse.ArgumentParser(
         description='Load data into Snowflake using parallel streaming with custom SQL queries',
         formatter_class=argparse.RawDescriptionHelpFormatter,
-        epilog=__doc__
+        epilog=__doc__,
     )

     # Required arguments
     required = parser.add_argument_group('required arguments')
-    required.add_argument(
-        '--query-file',
-        required=True,
-        help='Path to SQL query file to execute'
-    )
-    required.add_argument(
-        '--table-name',
-        required=True,
-        help='Destination Snowflake table name'
-    )
+    required.add_argument('--query-file', required=True, help='Path to SQL query file to execute')
+    required.add_argument('--table-name', required=True, help='Destination Snowflake table name')

     # Block range arguments (mutually exclusive groups)
     block_range = parser.add_argument_group('block range')
-    block_range.add_argument(
-        '--blocks',
-        type=int,
-        help='Number of recent blocks to load (auto-detect range)'
-    )
-    block_range.add_argument(
-        '--min-block',
-        type=int,
-        help='Explicit start block (requires --max-block)'
-    )
-    block_range.add_argument(
-        '--max-block',
-        type=int,
-        help='Explicit end block (requires --min-block)'
-    )
+    block_range.add_argument('--blocks', type=int, help='Number of recent blocks to load (auto-detect range)')
+    block_range.add_argument('--min-block', type=int, help='Explicit start block (requires --max-block)')
+    block_range.add_argument('--max-block', type=int, help='Explicit end block (requires --min-block)')
     block_range.add_argument(
         '--source-table',
         default='eth_firehose.logs',
-        help='Table for block range detection (default: eth_firehose.logs)'
+        help='Table for block range detection (default: eth_firehose.logs)',
     )
     block_range.add_argument(
-        '--block-column',
-        default='block_num',
-        help='Column name for block partitioning (default: block_num)'
+        '--block-column', default='block_num', help='Column name for block partitioning (default: block_num)'
     )

     # Label configuration (all optional)
     labels = parser.add_argument_group('label configuration (optional)')
-    labels.add_argument(
-        '--label-csv',
-        help='Path to CSV file with label data'
-    )
-    labels.add_argument(
-        '--label-name',
-        help='Label identifier (required if --label-csv provided)'
-    )
-    labels.add_argument(
-        '--label-key',
-        help='CSV column for joining (required if --label-csv provided)'
-    )
-    labels.add_argument(
-        '--stream-key',
-        help='Stream column for joining (required if --label-csv provided)'
-    )
+    labels.add_argument('--label-csv', help='Path to CSV file with label data')
+    labels.add_argument('--label-name', help='Label identifier (required if --label-csv provided)')
+    labels.add_argument('--label-key', help='CSV column for joining (required if --label-csv provided)')
+    labels.add_argument('--stream-key', help='Stream column for joining (required if --label-csv provided)')

     # Snowflake configuration
     snowflake = parser.add_argument_group('snowflake configuration')
     snowflake.add_argument(
-        '--connection-name',
-        help='Snowflake connection name (default: auto-generated from table name)'
+        '--connection-name', help='Snowflake connection name (default: auto-generated from table name)'
     )
     snowflake.add_argument(
         '--loading-method',
         choices=['snowpipe_streaming', 'stage', 'insert'],
         default='snowpipe_streaming',
-        help='Snowflake loading method (default: snowpipe_streaming)'
+        help='Snowflake loading method (default: snowpipe_streaming)',
     )
     snowflake.add_argument(
         '--preserve-reorg-history',
         action='store_true',
         default=True,
-        help='Enable reorg history preservation (default: enabled)'
+        help='Enable reorg history preservation (default: enabled)',
     )
     snowflake.add_argument(
         '--no-preserve-reorg-history',
         action='store_false',
         dest='preserve_reorg_history',
-        help='Disable reorg history preservation'
-    )
-    snowflake.add_argument(
-        '--disable-state',
-        action='store_true',
-        help='Disable state management (job resumption)'
-    )
-    snowflake.add_argument(
-        '--pool-size',
-        type=int,
-        help='Connection pool size (default: workers + 2)'
+        help='Disable reorg history preservation',
     )
+    snowflake.add_argument('--disable-state', action='store_true', help='Disable state management (job resumption)')
+    snowflake.add_argument('--pool-size', type=int, help='Connection pool size (default: workers + 2)')

     # Parallel execution configuration
     parallel = parser.add_argument_group('parallel execution')
-    parallel.add_argument(
-        '--workers',
-        type=int,
-        default=4,
-        help='Number of parallel workers (default: 4)'
-    )
+    parallel.add_argument('--workers', type=int, default=4, help='Number of parallel workers (default: 4)')
     parallel.add_argument(
         '--flush-interval',
         type=float,
         default=1.0,
-        help='Snowpipe Streaming buffer flush interval in seconds (default: 1.0)'
+        help='Snowpipe Streaming buffer flush interval in seconds (default: 1.0)',
     )

     # Server configuration
     parser.add_argument(
         '--server',
         default=os.getenv('AMP_SERVER_URL', 'grpc://34.27.238.174:80'),
-        help='AMP server URL (default: from AMP_SERVER_URL env or grpc://34.27.238.174:80)'
+        help='AMP server URL (default: from AMP_SERVER_URL env or grpc://34.27.238.174:80)',
     )

     # Logging configuration
     parser.add_argument(
-        '--verbose',
-        action='store_true',
-        help='Enable verbose logging from Snowflake libraries (default: suppressed)'
+        '--verbose', action='store_true', help='Enable verbose logging from Snowflake libraries (default: suppressed)'
     )

     args = parser.parse_args()
@@ -445,8 +401,7 @@ def main():

     # Print results
     label_columns = f'{args.label_key} joined columns' if has_labels else ''
-    print_results(results, args.table_name, min_block, max_block, duration,
-                  args.workers, has_labels, label_columns)
+    print_results(results, args.table_name, min_block, max_block, duration, args.workers, has_labels, label_columns)

     return args.table_name, sum(r.rows_loaded for r in results if r.success), duration

@@ -456,6 +411,7 @@ def main():
    except Exception as e:
        print(f'\n\n❌ Error: {e}')
        import traceback
+
        traceback.print_exc()
        sys.exit(1)

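Note: the --preserve-reorg-history / --no-preserve-reorg-history pair reformatted above relies on argparse's paired store_true/store_false actions writing to a shared dest, so the default stays True unless the negative flag is passed. A minimal standalone sketch of just that toggle (not the project's full CLI):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    '--preserve-reorg-history', action='store_true', default=True,
    help='Enable reorg history preservation (default: enabled)',
)
parser.add_argument(
    '--no-preserve-reorg-history', action='store_false', dest='preserve_reorg_history',
    help='Disable reorg history preservation',
)

print(parser.parse_args([]).preserve_reorg_history)                               # True
print(parser.parse_args(['--no-preserve-reorg-history']).preserve_reorg_history)  # False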
apps/test_erc20_labeled_parallel.py

Lines changed: 1 addition & 3 deletions
@@ -55,9 +55,7 @@ def get_recent_block_range(client: Client, num_blocks: int = 100_000):
     return min_block, max_block


-def load_erc20_transfers_with_labels(
-    num_blocks: int = 100_000, num_workers: int = 4, flush_interval: float = 1.0
-):
+def load_erc20_transfers_with_labels(num_blocks: int = 100_000, num_workers: int = 4, flush_interval: float = 1.0):
     """Load ERC20 transfers with token labels using Snowpipe Streaming and parallel streaming."""

     # Initialize client

src/amp/client.py

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ def load(
         destination: str,
         config: Dict[str, Any] = None,
         label_config: Optional[LabelJoinConfig] = None,
-        **kwargs
+        **kwargs,
     ) -> Union[LoadResult, Iterator[LoadResult]]:
         """
         Load query results to specified destination

src/amp/loaders/base.py

Lines changed: 24 additions & 29 deletions
@@ -6,7 +6,6 @@
 import time
 from abc import ABC, abstractmethod
 from dataclasses import fields, is_dataclass
-from datetime import UTC, datetime
 from logging import Logger
 from typing import Any, Dict, Generic, Iterator, List, Optional, Set, TypeVar

@@ -261,10 +260,7 @@ def _try_load_batch(self, batch: pa.RecordBatch, table_name: str, **kwargs) -> L
         if label_config:
             # Perform the join
             batch = self._join_with_labels(
-                batch,
-                label_config.label_name,
-                label_config.label_key_column,
-                label_config.stream_key_column
+                batch, label_config.label_name, label_config.label_key_column, label_config.stream_key_column
             )
             self.logger.debug(
                 f'Joined batch with label {label_config.label_name}: {batch.num_rows} rows after join '
@@ -478,9 +474,7 @@ def load_stream_continuous(

             # Choose processing strategy: transactional vs non-transactional
             use_transactional = (
-                hasattr(self, 'load_batch_transactional')
-                and self.state_enabled
-                and response.metadata.ranges
+                hasattr(self, 'load_batch_transactional') and self.state_enabled and response.metadata.ranges
             )

             if use_transactional:
@@ -636,9 +630,7 @@ def _process_batch_transactional(
         try:
             # Delegate to loader-specific transactional implementation
             # Loaders that support transactions implement load_batch_transactional()
-            rows_loaded_batch = self.load_batch_transactional(
-                batch_data, table_name, connection_name, ranges
-            )
+            rows_loaded_batch = self.load_batch_transactional(batch_data, table_name, connection_name, ranges)
             duration = time.time() - start_time

             # Mark batches as processed in state store after successful transaction
@@ -703,7 +695,9 @@ def _process_batch_non_transactional(

         if is_duplicate:
             # Skip this batch - already processed
-            self.logger.info(f'Skipping duplicate batch: {len(ranges)} ranges already processed for {table_name}')
+            self.logger.info(
+                f'Skipping duplicate batch: {len(ranges)} ranges already processed for {table_name}'
+            )
             return LoadResult(
                 rows_loaded=0,
                 duration=0.0,
@@ -731,7 +725,6 @@ def _process_batch_non_transactional(

         return result

-
     def _augment_streaming_result(
         self, result: LoadResult, batch_count: int, ranges: Optional[List[BlockRange]], ranges_complete: bool
     ) -> LoadResult:
@@ -808,23 +801,26 @@ def _add_metadata_columns(self, data: pa.RecordBatch, block_ranges: List[BlockRa
         # Convert BlockRanges to BatchIdentifiers and get compact unique IDs
         batch_ids = [BatchIdentifier.from_block_range(br) for br in block_ranges]
         # Combine multiple batch IDs with "|" separator for multi-network batches
-        batch_id_str = "|".join(bid.unique_id for bid in batch_ids)
+        batch_id_str = '|'.join(bid.unique_id for bid in batch_ids)
         batch_id_array = pa.array([batch_id_str] * num_rows, type=pa.string())
         result = result.append_column('_amp_batch_id', batch_id_array)

         # Optionally add full JSON for debugging/auditing
         if self.store_full_metadata:
             import json
-            ranges_json = json.dumps([
-                {
-                    'network': br.network,
-                    'start': br.start,
-                    'end': br.end,
-                    'hash': br.hash,
-                    'prev_hash': br.prev_hash
-                }
-                for br in block_ranges
-            ])
+
+            ranges_json = json.dumps(
+                [
+                    {
+                        'network': br.network,
+                        'start': br.start,
+                        'end': br.end,
+                        'hash': br.hash,
+                        'prev_hash': br.prev_hash,
+                    }
+                    for br in block_ranges
+                ]
+            )
             ranges_array = pa.array([ranges_json] * num_rows, type=pa.string())
             result = result.append_column('_amp_block_ranges', ranges_array)

@@ -966,7 +962,6 @@ def _join_with_labels(

         # If types don't match, cast one to match the other
         # Prefer casting to binary since that's more efficient
-        import pyarrow.compute as pc

         type_conversion_time_ms = 0.0
         if stream_key_type != label_key_type:
@@ -1032,14 +1027,14 @@ def hex_to_binary(value):
             timing_msg = (
                 f'⏱️ Label join: {input_rows}→{output_rows} rows in {total_time_ms:.2f}ms '
                 f'(type_conv={type_conversion_time_ms:.2f}ms, join={join_time_ms:.2f}ms, '
-                f'{output_rows/total_time_ms*1000:.0f} rows/sec) '
-                f'[label={label_name}, retained={output_rows/input_rows*100:.1f}%]\n'
+                f'{output_rows / total_time_ms * 1000:.0f} rows/sec) '
+                f'[label={label_name}, retained={output_rows / input_rows * 100:.1f}%]\n'
             )
         else:
             timing_msg = (
                 f'⏱️ Label join: {input_rows}→{output_rows} rows in {total_time_ms:.2f}ms '
-                f'(join={join_time_ms:.2f}ms, {output_rows/total_time_ms*1000:.0f} rows/sec) '
-                f'[label={label_name}, retained={output_rows/input_rows*100:.1f}%]\n'
+                f'(join={join_time_ms:.2f}ms, {output_rows / total_time_ms * 1000:.0f} rows/sec) '
+                f'[label={label_name}, retained={output_rows / input_rows * 100:.1f}%]\n'
             )

         sys.stderr.write(timing_msg)

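Note: the _add_metadata_columns hunk above stamps every row with a compact '_amp_batch_id' string, joining one ID per block range with '|'. A minimal standalone sketch of that mechanic follows; the ID strings are invented (the real code derives them via BatchIdentifier.from_block_range), and RecordBatch.append_column assumes a reasonably recent pyarrow, as the hunk itself does.

import pyarrow as pa

batch = pa.RecordBatch.from_pydict({'block_num': [100, 101, 102]})
batch_ids = ['eth:100-199:abcd', 'base:500-599:ef01']  # hypothetical per-range IDs
batch_id_str = '|'.join(batch_ids)  # multi-network batches combined with '|'
batch_id_array = pa.array([batch_id_str] * batch.num_rows, type=pa.string())
result = batch.append_column('_amp_batch_id', batch_id_array)
print(result.schema)  # block_num: int64, _amp_batch_id: string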
src/amp/loaders/implementations/deltalake_loader.py

Lines changed: 0 additions & 1 deletion
@@ -1,6 +1,5 @@
 # src/amp/loaders/implementations/deltalake_loader.py

-import json
 import os
 import time
 from dataclasses import dataclass, field

src/amp/loaders/implementations/lmdb_loader.py

Lines changed: 0 additions & 1 deletion
@@ -1,7 +1,6 @@
 # amp/loaders/implementations/lmdb_loader.py

 import hashlib
-import json
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Any, Dict, List, Optional
