|
| 1 | +"""Tests for label joining functionality in base DataLoader""" |
| 2 | + |
| 3 | +import tempfile |
| 4 | +from pathlib import Path |
| 5 | +from typing import Any, Dict |
| 6 | + |
| 7 | +import pyarrow as pa |
| 8 | +import pytest |
| 9 | + |
| 10 | +from amp.config.label_manager import LabelManager |
| 11 | +from amp.loaders.base import DataLoader |
| 12 | + |
| 13 | + |
| 14 | +class MockLoader(DataLoader): |
| 15 | + """Mock loader for testing""" |
| 16 | + |
| 17 | + def __init__(self, config: Dict[str, Any], label_manager=None): |
| 18 | + super().__init__(config, label_manager=label_manager) |
| 19 | + |
| 20 | + def _parse_config(self, config: Dict[str, Any]) -> Dict[str, Any]: |
| 21 | + """Override to just return the dict without parsing""" |
| 22 | + return config |
| 23 | + |
| 24 | + def connect(self) -> None: |
| 25 | + self._is_connected = True |
| 26 | + |
| 27 | + def disconnect(self) -> None: |
| 28 | + self._is_connected = False |
| 29 | + |
| 30 | + def _load_batch_impl(self, batch: pa.RecordBatch, table_name: str, **kwargs) -> int: |
| 31 | + return batch.num_rows |
| 32 | + |
| 33 | + |
| 34 | +class TestLabelJoining: |
| 35 | + """Test label joining functionality""" |
| 36 | + |
| 37 | + @pytest.fixture |
| 38 | + def label_manager(self): |
| 39 | + """Create a label manager with test data""" |
| 40 | + # Create a temporary CSV file with token labels (valid 40-char Ethereum addresses) |
| 41 | + with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f: |
| 42 | + f.write('address,symbol,decimals\n') |
| 43 | + f.write('0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa,USDC,6\n') |
| 44 | + f.write('0xbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb,WETH,18\n') |
| 45 | + f.write('0xcccccccccccccccccccccccccccccccccccccccc,DAI,18\n') |
| 46 | + csv_path = f.name |
| 47 | + |
| 48 | + try: |
| 49 | + manager = LabelManager() |
| 50 | + manager.add_label('tokens', csv_path) |
| 51 | + yield manager |
| 52 | + finally: |
| 53 | + Path(csv_path).unlink() |
| 54 | + |
| 55 | + def test_get_effective_schema(self, label_manager): |
| 56 | + """Test schema merging with label columns""" |
| 57 | + loader = MockLoader({}, label_manager=label_manager) |
| 58 | + |
| 59 | + # Original schema |
| 60 | + original_schema = pa.schema([('address', pa.string()), ('amount', pa.int64())]) |
| 61 | + |
| 62 | + # Get effective schema with labels |
| 63 | + effective_schema = loader._get_effective_schema(original_schema, 'tokens', 'address') |
| 64 | + |
| 65 | + # Should have original columns plus label columns (excluding join key) |
| 66 | + assert 'address' in effective_schema.names |
| 67 | + assert 'amount' in effective_schema.names |
| 68 | + assert 'symbol' in effective_schema.names # From label |
| 69 | + assert 'decimals' in effective_schema.names # From label |
| 70 | + |
| 71 | + # Total: 2 original + 2 label columns (join key 'address' already in original) = 4 |
| 72 | + assert len(effective_schema) == 4 |
| 73 | + |
| 74 | + def test_get_effective_schema_no_labels(self, label_manager): |
| 75 | + """Test schema without labels returns original schema""" |
| 76 | + loader = MockLoader({}, label_manager=label_manager) |
| 77 | + |
| 78 | + original_schema = pa.schema([('address', pa.string()), ('amount', pa.int64())]) |
| 79 | + |
| 80 | + # No label specified |
| 81 | + effective_schema = loader._get_effective_schema(original_schema, None, None) |
| 82 | + |
| 83 | + assert effective_schema == original_schema |
| 84 | + |
| 85 | + def test_join_with_labels(self, label_manager): |
| 86 | + """Test joining batch data with labels""" |
| 87 | + loader = MockLoader({}, label_manager=label_manager) |
| 88 | + |
| 89 | + # Create test batch with transfers (using full 40-char addresses) |
| 90 | + batch = pa.RecordBatch.from_pydict( |
| 91 | + { |
| 92 | + 'address': [ |
| 93 | + '0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', |
| 94 | + '0xbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', |
| 95 | + '0xffffffffffffffffffffffffffffffffffffffff', |
| 96 | + ], # Last one doesn't exist in labels |
| 97 | + 'amount': [100, 200, 300], |
| 98 | + } |
| 99 | + ) |
| 100 | + |
| 101 | + # Join with labels (inner join should filter out 0xfff...) |
| 102 | + joined_batch = loader._join_with_labels(batch, 'tokens', 'address', 'address') |
| 103 | + |
| 104 | + # Should only have 2 rows (first two addresses, last one filtered out) |
| 105 | + assert joined_batch.num_rows == 2 |
| 106 | + |
| 107 | + # Should have original columns plus label columns |
| 108 | + assert 'address' in joined_batch.schema.names |
| 109 | + assert 'amount' in joined_batch.schema.names |
| 110 | + assert 'symbol' in joined_batch.schema.names |
| 111 | + assert 'decimals' in joined_batch.schema.names |
| 112 | + |
| 113 | + # Verify joined data - after type conversion and join, addresses should be binary |
| 114 | + joined_dict = joined_batch.to_pydict() |
| 115 | + # Convert binary back to hex for comparison |
| 116 | + addresses_hex = [addr.hex() for addr in joined_dict['address']] |
| 117 | + assert addresses_hex == ['aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', 'bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb'] |
| 118 | + assert joined_dict['amount'] == [100, 200] |
| 119 | + assert joined_dict['symbol'] == ['USDC', 'WETH'] |
| 120 | + # Decimals are strings because we force all CSV columns to strings for type safety |
| 121 | + assert joined_dict['decimals'] == ['6', '18'] |
| 122 | + |
| 123 | + def test_join_with_all_matching_keys(self, label_manager): |
| 124 | + """Test join when all keys match""" |
| 125 | + loader = MockLoader({}, label_manager=label_manager) |
| 126 | + |
| 127 | + batch = pa.RecordBatch.from_pydict( |
| 128 | + { |
| 129 | + 'address': [ |
| 130 | + '0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', |
| 131 | + '0xbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', |
| 132 | + '0xcccccccccccccccccccccccccccccccccccccccc', |
| 133 | + ], |
| 134 | + 'amount': [100, 200, 300], |
| 135 | + } |
| 136 | + ) |
| 137 | + |
| 138 | + joined_batch = loader._join_with_labels(batch, 'tokens', 'address', 'address') |
| 139 | + |
| 140 | + # All 3 rows should be present |
| 141 | + assert joined_batch.num_rows == 3 |
| 142 | + |
| 143 | + def test_join_with_no_matching_keys(self, label_manager): |
| 144 | + """Test join when no keys match""" |
| 145 | + loader = MockLoader({}, label_manager=label_manager) |
| 146 | + |
| 147 | + batch = pa.RecordBatch.from_pydict( |
| 148 | + { |
| 149 | + 'address': [ |
| 150 | + '0xdddddddddddddddddddddddddddddddddddddddd', |
| 151 | + '0xeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee', |
| 152 | + '0xffffffffffffffffffffffffffffffffffffffff', |
| 153 | + ], |
| 154 | + 'amount': [100, 200, 300], |
| 155 | + } |
| 156 | + ) |
| 157 | + |
| 158 | + joined_batch = loader._join_with_labels(batch, 'tokens', 'address', 'address') |
| 159 | + |
| 160 | + # Should have 0 rows (all filtered out) |
| 161 | + assert joined_batch.num_rows == 0 |
| 162 | + |
| 163 | + def test_join_invalid_label_name(self, label_manager): |
| 164 | + """Test join with non-existent label""" |
| 165 | + loader = MockLoader({}, label_manager=label_manager) |
| 166 | + |
| 167 | + batch = pa.RecordBatch.from_pydict({'address': ['0xA'], 'amount': [100]}) |
| 168 | + |
| 169 | + with pytest.raises(ValueError, match="Label 'nonexistent' not found"): |
| 170 | + loader._join_with_labels(batch, 'nonexistent', 'address', 'address') |
| 171 | + |
| 172 | + def test_join_invalid_stream_key(self, label_manager): |
| 173 | + """Test join with invalid stream key column""" |
| 174 | + loader = MockLoader({}, label_manager=label_manager) |
| 175 | + |
| 176 | + batch = pa.RecordBatch.from_pydict({'address': ['0xA'], 'amount': [100]}) |
| 177 | + |
| 178 | + with pytest.raises(ValueError, match="Stream key column 'nonexistent' not found"): |
| 179 | + loader._join_with_labels(batch, 'tokens', 'address', 'nonexistent') |
| 180 | + |
| 181 | + def test_join_invalid_label_key(self, label_manager): |
| 182 | + """Test join with invalid label key column""" |
| 183 | + loader = MockLoader({}, label_manager=label_manager) |
| 184 | + |
| 185 | + batch = pa.RecordBatch.from_pydict({'address': ['0xA'], 'amount': [100]}) |
| 186 | + |
| 187 | + with pytest.raises(ValueError, match="Label key column 'nonexistent' not found"): |
| 188 | + loader._join_with_labels(batch, 'tokens', 'nonexistent', 'address') |
| 189 | + |
| 190 | + def test_join_no_label_manager(self): |
| 191 | + """Test join when label manager not configured""" |
| 192 | + loader = MockLoader({}, label_manager=None) |
| 193 | + |
| 194 | + batch = pa.RecordBatch.from_pydict({'address': ['0xA'], 'amount': [100]}) |
| 195 | + |
| 196 | + with pytest.raises(ValueError, match='Label manager not configured'): |
| 197 | + loader._join_with_labels(batch, 'tokens', 'address', 'address') |
0 commit comments