|
| 1 | +""" |
| 2 | +Label Manager for CSV-based label datasets. |
| 3 | +
|
| 4 | +This module provides functionality to register and manage CSV label datasets |
| 5 | +that can be joined with streaming data during loading operations. |
| 6 | +""" |
| 7 | + |
| 8 | +import logging |
| 9 | +from typing import Dict, List, Optional |
| 10 | + |
| 11 | +import pyarrow as pa |
| 12 | +import pyarrow.csv as csv |
| 13 | + |
| 14 | + |
| 15 | +class LabelManager: |
| 16 | + """ |
| 17 | + Manages CSV label datasets for joining with streaming data. |
| 18 | +
|
| 19 | + Labels are registered by name and loaded as PyArrow Tables for efficient |
| 20 | + joining operations. This allows reuse of label datasets across multiple |
| 21 | + queries and loaders. |
| 22 | +
|
| 23 | + Example: |
| 24 | + >>> manager = LabelManager() |
| 25 | + >>> manager.add_label('token_labels', '/path/to/tokens.csv') |
| 26 | + >>> label_table = manager.get_label('token_labels') |
| 27 | + """ |
| 28 | + |
| 29 | + def __init__(self): |
| 30 | + self._labels: Dict[str, pa.Table] = {} |
| 31 | + self.logger = logging.getLogger(__name__) |
| 32 | + |
| 33 | + def add_label(self, name: str, csv_path: str, binary_columns: Optional[List[str]] = None) -> None: |
| 34 | + """ |
| 35 | + Load and register a CSV label dataset with automatic hex→binary conversion. |
| 36 | +
|
| 37 | + Hex string columns (like Ethereum addresses) are automatically converted to |
| 38 | + binary format for efficient storage and joining. This reduces memory usage |
| 39 | + by ~50% and improves join performance. |
| 40 | +
|
| 41 | + Args: |
| 42 | + name: Unique name for this label dataset |
| 43 | + csv_path: Path to the CSV file |
| 44 | + binary_columns: List of column names containing hex addresses to convert to binary. |
| 45 | + If None, auto-detects columns with 'address' in the name. |
| 46 | +
|
| 47 | + Raises: |
| 48 | + FileNotFoundError: If CSV file doesn't exist |
| 49 | + ValueError: If CSV cannot be parsed or name already exists |
| 50 | + """ |
| 51 | + if name in self._labels: |
| 52 | + self.logger.warning(f"Label '{name}' already exists, replacing with new data") |
| 53 | + |
| 54 | + try: |
| 55 | + # Load CSV as PyArrow Table (initially as strings) |
| 56 | + temp_table = csv.read_csv(csv_path, read_options=csv.ReadOptions(autogenerate_column_names=False)) |
| 57 | + |
| 58 | + # Force all columns to be strings initially |
| 59 | + column_types = {col_name: pa.string() for col_name in temp_table.column_names} |
| 60 | + convert_opts = csv.ConvertOptions(column_types=column_types) |
| 61 | + label_table = csv.read_csv(csv_path, convert_options=convert_opts) |
| 62 | + |
| 63 | + # Auto-detect or use specified binary columns |
| 64 | + if binary_columns is None: |
| 65 | + # Auto-detect columns with 'address' in name (case-insensitive) |
| 66 | + binary_columns = [col for col in label_table.column_names if 'address' in col.lower()] |
| 67 | + |
| 68 | + # Convert hex string columns to binary for efficiency |
| 69 | + converted_columns = [] |
| 70 | + for col_name in binary_columns: |
| 71 | + if col_name not in label_table.column_names: |
| 72 | + self.logger.warning(f"Binary column '{col_name}' not found in CSV, skipping") |
| 73 | + continue |
| 74 | + |
| 75 | + hex_col = label_table.column(col_name) |
| 76 | + |
| 77 | + # Detect hex string format and convert to binary |
| 78 | + # Sample first non-null value to determine format |
| 79 | + sample_value = None |
| 80 | + for v in hex_col.to_pylist()[:100]: # Check first 100 values |
| 81 | + if v is not None: |
| 82 | + sample_value = v |
| 83 | + break |
| 84 | + |
| 85 | + if sample_value is None: |
| 86 | + self.logger.warning(f"Column '{col_name}' has no non-null values, skipping conversion") |
| 87 | + continue |
| 88 | + |
| 89 | + # Detect if it's a hex string (with or without 0x prefix) |
| 90 | + if isinstance(sample_value, str) and all(c in '0123456789abcdefABCDEFx' for c in sample_value): |
| 91 | + # Determine binary length from hex string |
| 92 | + hex_str = sample_value[2:] if sample_value.startswith('0x') else sample_value |
| 93 | + binary_length = len(hex_str) // 2 |
| 94 | + |
| 95 | + # Convert all values to binary (fixed-size to match streaming data) |
| 96 | + def hex_to_binary(v): |
| 97 | + if v is None: |
| 98 | + return None |
| 99 | + hex_str = v[2:] if v.startswith('0x') else v |
| 100 | + return bytes.fromhex(hex_str) |
| 101 | + |
| 102 | + binary_values = pa.array( |
| 103 | + [hex_to_binary(v) for v in hex_col.to_pylist()], |
| 104 | + type=pa.binary( |
| 105 | + binary_length |
| 106 | + ), # Fixed-size binary to match server data (e.g., 20 bytes for addresses) |
| 107 | + ) |
| 108 | + |
| 109 | + # Replace the column |
| 110 | + label_table = label_table.set_column( |
| 111 | + label_table.schema.get_field_index(col_name), col_name, binary_values |
| 112 | + ) |
| 113 | + converted_columns.append(f'{col_name} (hex→fixed_size_binary[{binary_length}])') |
| 114 | + self.logger.info(f"Converted '{col_name}' from hex string to fixed_size_binary[{binary_length}]") |
| 115 | + |
| 116 | + self._labels[name] = label_table |
| 117 | + |
| 118 | + conversion_info = f', converted: {", ".join(converted_columns)}' if converted_columns else '' |
| 119 | + self.logger.info( |
| 120 | + f"Loaded label '{name}' from {csv_path}: " |
| 121 | + f'{label_table.num_rows:,} rows, {len(label_table.schema)} columns ' |
| 122 | + f'({", ".join(label_table.schema.names)}){conversion_info}' |
| 123 | + ) |
| 124 | + |
| 125 | + except FileNotFoundError: |
| 126 | + raise FileNotFoundError(f'Label CSV file not found: {csv_path}') |
| 127 | + except Exception as e: |
| 128 | + raise ValueError(f"Failed to load label CSV '{csv_path}': {e}") from e |
| 129 | + |
| 130 | + def get_label(self, name: str) -> Optional[pa.Table]: |
| 131 | + """ |
| 132 | + Get label table by name. |
| 133 | +
|
| 134 | + Args: |
| 135 | + name: Name of the label dataset |
| 136 | +
|
| 137 | + Returns: |
| 138 | + PyArrow Table containing label data, or None if not found |
| 139 | + """ |
| 140 | + return self._labels.get(name) |
| 141 | + |
| 142 | + def list_labels(self) -> List[str]: |
| 143 | + """ |
| 144 | + List all registered label names. |
| 145 | +
|
| 146 | + Returns: |
| 147 | + List of label names |
| 148 | + """ |
| 149 | + return list(self._labels.keys()) |
| 150 | + |
| 151 | + def remove_label(self, name: str) -> bool: |
| 152 | + """ |
| 153 | + Remove a label dataset. |
| 154 | +
|
| 155 | + Args: |
| 156 | + name: Name of the label to remove |
| 157 | +
|
| 158 | + Returns: |
| 159 | + True if label was removed, False if it didn't exist |
| 160 | + """ |
| 161 | + if name in self._labels: |
| 162 | + del self._labels[name] |
| 163 | + self.logger.info(f"Removed label '{name}'") |
| 164 | + return True |
| 165 | + return False |
| 166 | + |
| 167 | + def clear(self) -> None: |
| 168 | + """Remove all label datasets.""" |
| 169 | + count = len(self._labels) |
| 170 | + self._labels.clear() |
| 171 | + self.logger.info(f'Cleared {count} label dataset(s)') |
0 commit comments