
Commit 9fdb633

feat: Integrate label manager into Client for enriched streaming
Add label management to Client class:
- Initialize LabelManager with configurable label directory
- Support loading labels from CSV files
- Pass label_manager to all loader instances
- Enable label joining in streaming queries via load() method

Updates:
- Client now supports label enrichment out of the box
- Loaders inherit label_manager from client
- Add pyarrow.csv dependency for label loading
1 parent 7e79193 commit 9fdb633
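
For orientation, a minimal end-to-end sketch of the surface this commit adds. `Client`, `configure_label`, `configure_connection`, `sql().load()`, and `LabelJoinConfig` come from the diffs below; the URL, loader name, CSV path, connection config, and the `LabelJoinConfig` field names (assumed to mirror the deprecated `label` / `label_key_column` / `stream_key_column` kwargs) are illustrative assumptions, not confirmed API:

```python
# Hypothetical usage sketch -- LabelJoinConfig field names are assumed to
# mirror the deprecated label / label_key_column / stream_key_column kwargs.
from amp.client import Client
from amp.loaders.types import LabelJoinConfig

client = Client('grpc://localhost:8815')  # Flight SQL endpoint (illustrative URL)

# Register a label dataset; hex address columns are auto-detected unless listed.
client.configure_label('token_metadata', 'labels/tokens.csv')

# Configure a destination connection (loader name and config are illustrative).
client.configure_connection('pg', 'postgresql', {'dsn': 'postgresql://localhost/amp'})

# Stream query results into the destination, enriched with the label data.
results = client.sql('SELECT * FROM transfers').load(
    connection='pg',
    destination='transfers_enriched',
    label_config=LabelJoinConfig(
        label_name='token_metadata',   # assumed field names
        label_key_column='address',
        stream_key_column='token_address',
    ),
    stream=True,
)
for result in results:
    print(result)
```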

2 files changed: +113 -22 lines

pyproject.toml

Lines changed: 7 additions & 6 deletions
```diff
@@ -12,21 +12,17 @@ dependencies = [
     "pandas>=2.3.1",
     "pyarrow>=20.0.0",
     "typer>=0.15.2",
-
     # Flight SQL support
     "adbc-driver-manager>=1.5.0",
     "adbc-driver-postgresql>=1.5.0",
     "protobuf>=4.21.0",
-
     # Ethereum/blockchain utilities
     "base58>=2.1.1",
     "eth-hash[pysha3]>=0.7.1",
     "eth-utils>=5.2.0",
-
     # Google Cloud support
     "google-cloud-bigquery>=3.30.0",
     "google-cloud-storage>=3.1.0",
-
     # Arro3 for enhanced PyArrow operations
     "arro3-core>=0.5.1",
     "arro3-compute>=0.5.1",
@@ -58,7 +54,8 @@ iceberg = [
 ]
 
 snowflake = [
-    "snowflake-connector-python>=3.5.0",
+    "snowflake-connector-python>=4.0.0",
+    "snowpipe-streaming>=1.0.0",  # Snowpipe Streaming API
 ]
 
 lmdb = [
@@ -71,7 +68,8 @@ all_loaders = [
     "deltalake>=1.0.2",  # Delta Lake (consistent version)
     "pyiceberg[sql-sqlite]>=0.10.0",  # Apache Iceberg
     "pydantic>=2.0,<2.12",  # PyIceberg 0.10.0 compatibility
-    "snowflake-connector-python>=3.5.0",  # Snowflake
+    "snowflake-connector-python>=4.0.0",  # Snowflake
+    "snowpipe-streaming>=1.0.0",  # Snowpipe Streaming API
     "lmdb>=1.4.0",  # LMDB
 ]
 
@@ -91,6 +89,9 @@ test = [
 requires = ["hatchling"]
 build-backend = "hatchling.build"
 
+[tool.hatch.build.targets.wheel]
+packages = ["src/amp"]
+
 [tool.pytest.ini_options]
 pythonpath = ["."]
 testpaths = ["tests"]
```

src/amp/client.py

Lines changed: 106 additions & 16 deletions
```diff
@@ -7,8 +7,9 @@
 
 from . import FlightSql_pb2
 from .config.connection_manager import ConnectionManager
+from .config.label_manager import LabelManager
 from .loaders.registry import create_loader, get_available_loaders
-from .loaders.types import LoadConfig, LoadMode, LoadResult
+from .loaders.types import LabelJoinConfig, LoadConfig, LoadMode, LoadResult
 from .streaming import (
     ParallelConfig,
     ParallelStreamExecutor,
@@ -28,7 +29,12 @@ def __init__(self, client: 'Client', query: str):
         self.logger = logging.getLogger(__name__)
 
     def load(
-        self, connection: str, destination: str, config: Dict[str, Any] = None, **kwargs
+        self,
+        connection: str,
+        destination: str,
+        config: Dict[str, Any] = None,
+        label_config: Optional[LabelJoinConfig] = None,
+        **kwargs
     ) -> Union[LoadResult, Iterator[LoadResult]]:
         """
         Load query results to specified destination
@@ -38,12 +44,16 @@ def load(
             destination: Target destination (table name, key, path, etc.)
             connection: Named connection or connection name for auto-discovery
             config: Inline configuration dict (alternative to connection)
+            label_config: Optional LabelJoinConfig for joining with label data
             **kwargs: Additional loader-specific options including:
                 - read_all: bool = False (if True, loads entire table at once; if False, streams batch by batch)
                 - batch_size: int = 10000 (size of each batch for streaming)
                 - stream: bool = False (if True, enables continuous streaming with reorg detection)
                 - with_reorg_detection: bool = True (enable reorg detection for streaming queries)
                 - resume_watermark: Optional[ResumeWatermark] = None (resume streaming from specific point)
+                - label: str (deprecated, use label_config instead)
+                - label_key_column: str (deprecated, use label_config instead)
+                - stream_key_column: str (deprecated, use label_config instead)
 
         Returns:
             - If read_all=True: Single LoadResult with operation details
@@ -58,7 +68,12 @@ def load(
             # TODO: Add validation that the specific query uses features supported by streaming
             streaming_query = self._ensure_streaming_query(self.query)
             return self.client.query_and_load_streaming(
-                query=streaming_query, destination=destination, connection_name=connection, config=config, **kwargs
+                query=streaming_query,
+                destination=destination,
+                connection_name=connection,
+                config=config,
+                label_config=label_config,
+                **kwargs,
             )
 
         # Validate that parallel_config is only used with stream=True
@@ -69,7 +84,12 @@ def load(
         kwargs.setdefault('read_all', False)
 
         return self.client.query_and_load(
-            query=self.query, destination=destination, connection_name=connection, config=config, **kwargs
+            query=self.query,
+            destination=destination,
+            connection_name=connection,
+            config=config,
+            label_config=label_config,
+            **kwargs,
         )
 
     def _ensure_streaming_query(self, query: str) -> str:
```
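
As context for the widened `load()` signature, a short sketch of the three dispatch paths the docstring above describes; the query, connection name, and destination are illustrative:

```python
qb = client.sql('SELECT block_num, tx_hash FROM transfers')

# read_all=True: materialize the entire result, returns a single LoadResult.
result = qb.load(connection='pg', destination='transfers', read_all=True)

# Default: stream batch by batch, yielding one LoadResult per batch.
for r in qb.load(connection='pg', destination='transfers', batch_size=10_000):
    print(r)

# stream=True: continuous streaming with reorg detection, routed to
# query_and_load_streaming() after _ensure_streaming_query() rewrites the query.
for r in qb.load(connection='pg', destination='transfers', stream=True):
    print(r)
```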
```diff
@@ -105,6 +125,7 @@ class Client:
     def __init__(self, url):
         self.conn = flight.connect(url)
         self.connection_manager = ConnectionManager()
+        self.label_manager = LabelManager()
         self.logger = logging.getLogger(__name__)
 
     def sql(self, query: str) -> QueryBuilder:
@@ -123,6 +144,18 @@ def configure_connection(self, name: str, loader: str, config: Dict[str, Any]) -
         """Configure a named connection for reuse"""
         self.connection_manager.add_connection(name, loader, config)
 
+    def configure_label(self, name: str, csv_path: str, binary_columns: Optional[List[str]] = None) -> None:
+        """
+        Configure a label dataset from a CSV file for joining with streaming data.
+
+        Args:
+            name: Unique name for this label dataset
+            csv_path: Path to the CSV file
+            binary_columns: List of column names containing hex addresses to convert to binary.
+                If None, auto-detects columns with 'address' in the name.
+        """
+        self.label_manager.add_label(name, csv_path, binary_columns)
+
     def list_connections(self) -> Dict[str, str]:
         """List all configured connections"""
         return self.connection_manager.list_connections()
```
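
A quick sketch of the new `configure_label()` hook in use; the dataset names, file paths, and column names are made up:

```python
# Rely on auto-detection: columns with 'address' in the name are converted
# from hex strings to binary.
client.configure_label('exchanges', 'labels/exchanges.csv')

# Or name the hex-address columns explicitly.
client.configure_label(
    'contracts',
    'labels/contracts.csv',
    binary_columns=['contract_address', 'deployer_address'],
)
```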
```diff
@@ -162,7 +195,13 @@ def _batch_generator(self, reader):
                 break
 
     def query_and_load(
-        self, query: str, destination: str, connection_name: str, config: Optional[Dict[str, Any]] = None, **kwargs
+        self,
+        query: str,
+        destination: str,
+        connection_name: str,
+        config: Optional[Dict[str, Any]] = None,
+        label_config: Optional[LabelJoinConfig] = None,
+        **kwargs,
     ) -> Union[LoadResult, Iterator[LoadResult]]:
         """
         Execute query and load results directly into target system
@@ -211,6 +250,13 @@ def query_and_load(
             **{k: v for k, v in kwargs.items() if k in ['max_retries', 'retry_delay']},
         )
 
+        # Remove known LoadConfig params from kwargs, leaving loader-specific params
+        for key in ['max_retries', 'retry_delay']:
+            kwargs.pop(key, None)
+
+        # Remaining kwargs are loader-specific (e.g., channel_suffix for Snowflake)
+        loader_specific_kwargs = kwargs
+
         if read_all:
             self.logger.info(f'Loading entire query result to {loader_type}:{destination}')
         else:
@@ -221,20 +267,36 @@ def query_and_load(
         # Get the data and load
         if read_all:
             table = self.get_sql(query, read_all=True)
-            return self._load_table(table, loader_type, destination, loader_config, load_config)
+            return self._load_table(
+                table,
+                loader_type,
+                destination,
+                loader_config,
+                load_config,
+                label_config=label_config,
+                **loader_specific_kwargs,
+            )
         else:
             batch_stream = self.get_sql(query, read_all=False)
-            return self._load_stream(batch_stream, loader_type, destination, loader_config, load_config)
+            return self._load_stream(
+                batch_stream,
+                loader_type,
+                destination,
+                loader_config,
+                load_config,
+                label_config=label_config,
+                **loader_specific_kwargs,
+            )
 
     def _load_table(
-        self, table: pa.Table, loader: str, table_name: str, config: Dict[str, Any], load_config: LoadConfig
+        self, table: pa.Table, loader: str, table_name: str, config: Dict[str, Any], load_config: LoadConfig, **kwargs
     ) -> LoadResult:
         """Load a complete Arrow Table"""
         try:
-            loader_instance = create_loader(loader, config)
+            loader_instance = create_loader(loader, config, label_manager=self.label_manager)
 
             with loader_instance:
-                return loader_instance.load_table(table, table_name, **load_config.__dict__)
+                return loader_instance.load_table(table, table_name, **load_config.__dict__, **kwargs)
         except Exception as e:
             self.logger.error(f'Failed to load table: {e}')
             return LoadResult(
@@ -254,13 +316,14 @@ def _load_stream(
         table_name: str,
         config: Dict[str, Any],
         load_config: LoadConfig,
+        **kwargs,
     ) -> Iterator[LoadResult]:
         """Load from a stream of batches"""
         try:
-            loader_instance = create_loader(loader, config)
+            loader_instance = create_loader(loader, config, label_manager=self.label_manager)
 
             with loader_instance:
-                yield from loader_instance.load_stream(batch_stream, table_name, **load_config.__dict__)
+                yield from loader_instance.load_stream(batch_stream, table_name, **load_config.__dict__, **kwargs)
         except Exception as e:
             self.logger.error(f'Failed to load stream: {e}')
             yield LoadResult(
```
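
The kwargs split above is what lets loader-specific options ride along with `query_and_load()`: `max_retries`/`retry_delay` are consumed by `LoadConfig`, everything else is forwarded to the loader. A sketch using the `channel_suffix` example from the inline comment (connection name and values are illustrative):

```python
results = client.query_and_load(
    query='SELECT * FROM transfers',
    destination='TRANSFERS',
    connection_name='snowflake_prod',
    max_retries=3,              # popped into LoadConfig
    channel_suffix='_shard_0',  # survives the pop(), forwarded to the Snowflake loader
)
```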
```diff
@@ -279,6 +342,7 @@ def query_and_load_streaming(
         destination: str,
         connection_name: str,
         config: Optional[Dict[str, Any]] = None,
+        label_config: Optional[LabelJoinConfig] = None,
         with_reorg_detection: bool = True,
         resume_watermark: Optional[ResumeWatermark] = None,
         parallel_config: Optional[ParallelConfig] = None,
@@ -315,6 +379,10 @@ def query_and_load_streaming(
                 **{k: v for k, v in kwargs.items() if k in ['max_retries', 'retry_delay']},
             }
 
+            # Add label_config if provided
+            if label_config:
+                load_config_dict['label_config'] = label_config
+
             yield from executor.execute_parallel_stream(query, destination, connection_name, load_config_dict)
             return
 
@@ -346,6 +414,27 @@ def query_and_load_streaming(
 
         self.logger.info(f'Starting streaming query to {loader_type}:{destination}')
 
+        # Create loader instance early to access checkpoint store
+        loader_instance = create_loader(loader_type, loader_config, label_manager=self.label_manager)
+
+        # Load checkpoint and create resume watermark if enabled (default: enabled)
+        if resume_watermark is None and kwargs.get('resume', True):
+            try:
+                checkpoint = loader_instance.checkpoint_store.load(connection_name, destination)
+
+                if checkpoint:
+                    resume_watermark = checkpoint.to_resume_watermark()
+                    checkpoint_type = 'reorg checkpoint' if checkpoint.is_reorg else 'checkpoint'
+                    self.logger.info(
+                        f'Resuming from {checkpoint_type}: {len(checkpoint.ranges)} ranges, '
+                        f'timestamp {checkpoint.timestamp}'
+                    )
+                    if checkpoint.is_reorg:
+                        resume_points = ', '.join(f'{r.network}:{r.start}' for r in checkpoint.ranges)
+                        self.logger.info(f'Reorg resume points: {resume_points}')
+            except Exception as e:
+                self.logger.warning(f'Failed to load checkpoint, starting from beginning: {e}')
+
         try:
             # Execute streaming query with Flight SQL
             # Create a CommandStatementQuery message
@@ -376,12 +465,13 @@ def query_and_load_streaming(
                 stream_iterator = ReorgAwareStream(stream_iterator)
                 self.logger.info('Reorg detection enabled for streaming query')
 
-            # Create loader instance and start continuous loading
-            loader_instance = create_loader(loader_type, loader_config)
-
+            # Start continuous loading with checkpoint support
             with loader_instance:
                 self.logger.info(f'Starting continuous load to {destination}. Press Ctrl+C to stop.')
-                yield from loader_instance.load_stream_continuous(stream_iterator, destination, **load_config.__dict__)
+                # Pass connection_name for checkpoint saving
+                yield from loader_instance.load_stream_continuous(
+                    stream_iterator, destination, connection_name=connection_name, **load_config.__dict__
+                )
 
         except Exception as e:
             self.logger.error(f'Streaming query failed: {e}')
```
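
Checkpoint-based resumption is now on by default; a sketch of opting out, based on the `kwargs.get('resume', True)` check above (query and names are illustrative):

```python
# Ignore any saved checkpoint and stream from the beginning.
for result in client.query_and_load_streaming(
    query='SELECT * FROM transfers',  # assumed to already be in streaming form
    destination='transfers',
    connection_name='pg',
    resume=False,
):
    print(result)

# Passing resume_watermark= explicitly also skips the checkpoint lookup,
# since the store is only consulted when resume_watermark is None.
```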
