@@ -50,11 +50,12 @@ class DataLoader(ABC, Generic[TConfig]):
5050 REQUIRES_SCHEMA_MATCH : bool = True
5151 SUPPORTS_TRANSACTIONS : bool = False
5252
53- def __init__ (self , config : Dict [str , Any ]) -> None :
53+ def __init__ (self , config : Dict [str , Any ], label_manager = None ) -> None :
5454 self .logger : Logger = logging .getLogger (f'{ self .__class__ .__name__ } ' )
5555 self ._connection : Optional [Any ] = None
5656 self ._is_connected : bool = False
5757 self ._created_tables : Set [str ] = set () # Track created tables
58+ self .label_manager = label_manager # For CSV label joining
5859
5960 # Parse configuration into typed format
6061 self .config : TConfig = self ._parse_config (config )
@@ -240,6 +241,7 @@ def _try_load_batch(self, batch: pa.RecordBatch, table_name: str, **kwargs) -> L
240241 This is called by load_batch() within the retry loop. It handles:
241242 - Connection management
242243 - Mode validation
244+ - Label joining (if configured)
243245 - Table creation
244246 - Error handling and timing
245247 - Metadata generation
@@ -258,7 +260,26 @@ def _try_load_batch(self, batch: pa.RecordBatch, table_name: str, **kwargs) -> L
258260 if mode not in self .SUPPORTED_MODES :
259261 raise ValueError (f'Unsupported mode { mode } . Supported modes: { self .SUPPORTED_MODES } ' )
260262
261- # Handle table creation
263+ # Apply label joining if requested
264+ label_name = kwargs .get ('label' )
265+ label_key_column = kwargs .get ('label_key_column' )
266+ stream_key_column = kwargs .get ('stream_key_column' )
267+
268+ if label_name or label_key_column or stream_key_column :
269+ # If any label param is provided, all must be provided
270+ if not (label_name and label_key_column and stream_key_column ):
271+ raise ValueError (
272+ 'Label joining requires all three parameters: label, label_key_column, stream_key_column'
273+ )
274+
275+ # Perform the join
276+ batch = self ._join_with_labels (batch , label_name , label_key_column , stream_key_column )
277+ self .logger .debug (
278+ f'Joined batch with label { label_name } : { batch .num_rows } rows after join '
279+ f'(columns: { ", " .join (batch .schema .names )} )'
280+ )
281+
282+ # Handle table creation (use joined schema if applicable)
262283 if kwargs .get ('create_table' , True ) and table_name not in self ._created_tables :
263284 if hasattr (self , '_create_table_from_schema' ):
264285 self ._create_table_from_schema (batch .schema , table_name )
@@ -891,6 +912,133 @@ def _get_loader_table_metadata(
891912 """Override in subclasses to add loader-specific table metadata"""
892913 return {}
893914
def _get_effective_schema(
    self, original_schema: pa.Schema, label_name: Optional[str], label_key_column: Optional[str]
) -> pa.Schema:
    """
    Compute the schema a batch would have after label columns are appended.

    When either ``label_name`` or ``label_key_column`` is None the original
    schema is handed back untouched. Otherwise the label table is looked up
    through ``self.label_manager`` and its fields are appended after the
    original fields, skipping the label join key and any field whose name
    already appears in the original schema.

    Args:
        original_schema: Schema of the incoming data
        label_name: Name of the label dataset (None disables merging)
        label_key_column: Join-key column in the label table (None disables merging)

    Returns:
        The original schema, or a new schema made of the original fields
        followed by the additional label fields

    Raises:
        ValueError: If no label manager is configured or the label is not found
    """
    if label_name is None or label_key_column is None:
        return original_schema

    manager = self.label_manager
    if manager is None:
        raise ValueError('Label manager not configured')

    labels = manager.get_label(label_name)
    if labels is None:
        raise ValueError(f"Label '{label_name}' not found")

    # Names already taken by the original data; built once for the whole scan.
    taken_names = frozenset(original_schema.names)

    # Label fields worth adding: not the join key, and not shadowing an existing column.
    extra_fields = [
        label_field
        for label_field in labels.schema
        if label_field.name != label_key_column and label_field.name not in taken_names
    ]

    return pa.schema([*original_schema, *extra_fields])
951+
def _join_with_labels(
    self, batch: pa.RecordBatch, label_name: str, label_key_column: str, stream_key_column: str
) -> pa.RecordBatch:
    """
    Inner-join a data batch with a label table and return a single batch.

    Rows whose stream key has no matching label key are dropped (inner join).
    When the two key columns have different types, the stream key column is
    converted before joining in these two cases:

    - stream key is string, label key is fixed-size binary: each value is
      parsed as hex (an optional ``0x``/``0X`` prefix is stripped)
    - stream key is variable-length binary, label key is string: each value
      is rendered as a bare hex string (no prefix)

    Any other type mismatch is passed to the join unchanged.

    Args:
        batch: Original data batch
        label_name: Name of the label dataset
        label_key_column: Column name in the label table to join on
        stream_key_column: Column name in the batch data to join on

    Returns:
        One RecordBatch holding every joined row; an empty batch with the
        joined schema when nothing matches

    Raises:
        ValueError: If label_manager is not configured, the label is not
            found, a key column is missing, or a string key is not valid hex
    """
    if self.label_manager is None:
        raise ValueError('Label manager not configured')

    label_table = self.label_manager.get_label(label_name)
    if label_table is None:
        raise ValueError(f"Label '{label_name}' not found")

    # Validate columns exist
    if stream_key_column not in batch.schema.names:
        raise ValueError(f"Stream key column '{stream_key_column}' not found in batch schema")

    if label_key_column not in label_table.schema.names:
        raise ValueError(f"Label key column '{label_key_column}' not found in label table")

    # Table.join is only available on tables, not on record batches
    batch_table = pa.Table.from_batches([batch])

    stream_key_type = batch_table.schema.field(stream_key_column).type
    label_key_type = label_table.schema.field(label_key_column).type

    if stream_key_type != label_key_type:
        converted_keys = None

        if pa.types.is_fixed_size_binary(label_key_type) and pa.types.is_string(stream_key_type):
            # Hex string (e.g. "0xABCD...") -> fixed-size binary of the label key's width
            def hex_to_binary(value):
                if value is None:
                    return None
                hex_str = value[2:] if value.startswith(('0x', '0X')) else value
                return bytes.fromhex(hex_str)

            converted_keys = pa.array(
                [hex_to_binary(value) for value in batch_table.column(stream_key_column).to_pylist()],
                type=label_key_type,
            )
        elif pa.types.is_binary(stream_key_type) and pa.types.is_string(label_key_type):
            # Binary -> bare hex string. The explicit type matters: without it an
            # empty or all-null column would be inferred as null type and the
            # key types would still disagree at join time.
            # NOTE(review): fixed-size binary stream keys are not covered by
            # is_binary() and fall through unconverted - confirm whether intended.
            converted_keys = pa.array(
                [
                    value.hex() if value is not None else None
                    for value in batch_table.column(stream_key_column).to_pylist()
                ],
                type=pa.string(),
            )

        if converted_keys is not None:
            key_index = batch_table.schema.get_field_index(stream_key_column)
            batch_table = batch_table.set_column(key_index, stream_key_column, converted_keys)

    # Inner join filters out rows where the stream key doesn't match any label key
    joined_table = batch_table.join(
        label_table, keys=stream_key_column, right_keys=label_key_column, join_type='inner'
    )

    if joined_table.num_rows == 0:
        # to_batches() may yield nothing for an empty table, so build the empty
        # batch explicitly. Positional arrays keep duplicate column names intact.
        empty_columns = [pa.array([], type=field.type) for field in joined_table.schema]
        return pa.RecordBatch.from_arrays(empty_columns, schema=joined_table.schema)

    # The join result may be split across several chunks; merge them first so
    # the single returned batch contains every joined row, not just the first chunk.
    return joined_table.combine_chunks().to_batches()[0]
1041+
8941042 def __enter__ (self ) -> 'DataLoader' :
8951043 self .connect ()
8961044 return self
0 commit comments