Skip to content

Commit a121d3e

Browse files
committed
label manager tests
1 parent d25888c commit a121d3e

File tree

2 files changed

+349
-0
lines changed

2 files changed

+349
-0
lines changed

tests/unit/test_label_joining.py

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
"""Tests for label joining functionality in base DataLoader"""
2+
3+
import tempfile
4+
from pathlib import Path
5+
from typing import Any, Dict
6+
7+
import pyarrow as pa
8+
import pytest
9+
10+
from amp.config.label_manager import LabelManager
11+
from amp.loaders.base import DataLoader
12+
13+
14+
class MockLoader(DataLoader):
15+
"""Mock loader for testing"""
16+
17+
def __init__(self, config: Dict[str, Any], label_manager=None):
18+
super().__init__(config, label_manager=label_manager)
19+
20+
def _parse_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
21+
"""Override to just return the dict without parsing"""
22+
return config
23+
24+
def connect(self) -> None:
25+
self._is_connected = True
26+
27+
def disconnect(self) -> None:
28+
self._is_connected = False
29+
30+
def _load_batch_impl(self, batch: pa.RecordBatch, table_name: str, **kwargs) -> int:
31+
return batch.num_rows
32+
33+
34+
class TestLabelJoining:
35+
"""Test label joining functionality"""
36+
37+
@pytest.fixture
38+
def label_manager(self):
39+
"""Create a label manager with test data"""
40+
# Create a temporary CSV file with token labels (valid 40-char Ethereum addresses)
41+
with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f:
42+
f.write('address,symbol,decimals\n')
43+
f.write('0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa,USDC,6\n')
44+
f.write('0xbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb,WETH,18\n')
45+
f.write('0xcccccccccccccccccccccccccccccccccccccccc,DAI,18\n')
46+
csv_path = f.name
47+
48+
try:
49+
manager = LabelManager()
50+
manager.add_label('tokens', csv_path)
51+
yield manager
52+
finally:
53+
Path(csv_path).unlink()
54+
55+
def test_get_effective_schema(self, label_manager):
56+
"""Test schema merging with label columns"""
57+
loader = MockLoader({}, label_manager=label_manager)
58+
59+
# Original schema
60+
original_schema = pa.schema([('address', pa.string()), ('amount', pa.int64())])
61+
62+
# Get effective schema with labels
63+
effective_schema = loader._get_effective_schema(original_schema, 'tokens', 'address')
64+
65+
# Should have original columns plus label columns (excluding join key)
66+
assert 'address' in effective_schema.names
67+
assert 'amount' in effective_schema.names
68+
assert 'symbol' in effective_schema.names # From label
69+
assert 'decimals' in effective_schema.names # From label
70+
71+
# Total: 2 original + 2 label columns (join key 'address' already in original) = 4
72+
assert len(effective_schema) == 4
73+
74+
def test_get_effective_schema_no_labels(self, label_manager):
75+
"""Test schema without labels returns original schema"""
76+
loader = MockLoader({}, label_manager=label_manager)
77+
78+
original_schema = pa.schema([('address', pa.string()), ('amount', pa.int64())])
79+
80+
# No label specified
81+
effective_schema = loader._get_effective_schema(original_schema, None, None)
82+
83+
assert effective_schema == original_schema
84+
85+
def test_join_with_labels(self, label_manager):
86+
"""Test joining batch data with labels"""
87+
loader = MockLoader({}, label_manager=label_manager)
88+
89+
# Create test batch with transfers (using full 40-char addresses)
90+
batch = pa.RecordBatch.from_pydict(
91+
{
92+
'address': [
93+
'0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa',
94+
'0xbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb',
95+
'0xffffffffffffffffffffffffffffffffffffffff',
96+
], # Last one doesn't exist in labels
97+
'amount': [100, 200, 300],
98+
}
99+
)
100+
101+
# Join with labels (inner join should filter out 0xfff...)
102+
joined_batch = loader._join_with_labels(batch, 'tokens', 'address', 'address')
103+
104+
# Should only have 2 rows (first two addresses, last one filtered out)
105+
assert joined_batch.num_rows == 2
106+
107+
# Should have original columns plus label columns
108+
assert 'address' in joined_batch.schema.names
109+
assert 'amount' in joined_batch.schema.names
110+
assert 'symbol' in joined_batch.schema.names
111+
assert 'decimals' in joined_batch.schema.names
112+
113+
# Verify joined data - after type conversion and join, addresses should be binary
114+
joined_dict = joined_batch.to_pydict()
115+
# Convert binary back to hex for comparison
116+
addresses_hex = [addr.hex() for addr in joined_dict['address']]
117+
assert addresses_hex == ['aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', 'bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb']
118+
assert joined_dict['amount'] == [100, 200]
119+
assert joined_dict['symbol'] == ['USDC', 'WETH']
120+
# Decimals are strings because we force all CSV columns to strings for type safety
121+
assert joined_dict['decimals'] == ['6', '18']
122+
123+
def test_join_with_all_matching_keys(self, label_manager):
124+
"""Test join when all keys match"""
125+
loader = MockLoader({}, label_manager=label_manager)
126+
127+
batch = pa.RecordBatch.from_pydict(
128+
{
129+
'address': [
130+
'0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa',
131+
'0xbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb',
132+
'0xcccccccccccccccccccccccccccccccccccccccc',
133+
],
134+
'amount': [100, 200, 300],
135+
}
136+
)
137+
138+
joined_batch = loader._join_with_labels(batch, 'tokens', 'address', 'address')
139+
140+
# All 3 rows should be present
141+
assert joined_batch.num_rows == 3
142+
143+
def test_join_with_no_matching_keys(self, label_manager):
144+
"""Test join when no keys match"""
145+
loader = MockLoader({}, label_manager=label_manager)
146+
147+
batch = pa.RecordBatch.from_pydict(
148+
{
149+
'address': [
150+
'0xdddddddddddddddddddddddddddddddddddddddd',
151+
'0xeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
152+
'0xffffffffffffffffffffffffffffffffffffffff',
153+
],
154+
'amount': [100, 200, 300],
155+
}
156+
)
157+
158+
joined_batch = loader._join_with_labels(batch, 'tokens', 'address', 'address')
159+
160+
# Should have 0 rows (all filtered out)
161+
assert joined_batch.num_rows == 0
162+
163+
def test_join_invalid_label_name(self, label_manager):
164+
"""Test join with non-existent label"""
165+
loader = MockLoader({}, label_manager=label_manager)
166+
167+
batch = pa.RecordBatch.from_pydict({'address': ['0xA'], 'amount': [100]})
168+
169+
with pytest.raises(ValueError, match="Label 'nonexistent' not found"):
170+
loader._join_with_labels(batch, 'nonexistent', 'address', 'address')
171+
172+
def test_join_invalid_stream_key(self, label_manager):
173+
"""Test join with invalid stream key column"""
174+
loader = MockLoader({}, label_manager=label_manager)
175+
176+
batch = pa.RecordBatch.from_pydict({'address': ['0xA'], 'amount': [100]})
177+
178+
with pytest.raises(ValueError, match="Stream key column 'nonexistent' not found"):
179+
loader._join_with_labels(batch, 'tokens', 'address', 'nonexistent')
180+
181+
def test_join_invalid_label_key(self, label_manager):
182+
"""Test join with invalid label key column"""
183+
loader = MockLoader({}, label_manager=label_manager)
184+
185+
batch = pa.RecordBatch.from_pydict({'address': ['0xA'], 'amount': [100]})
186+
187+
with pytest.raises(ValueError, match="Label key column 'nonexistent' not found"):
188+
loader._join_with_labels(batch, 'tokens', 'nonexistent', 'address')
189+
190+
def test_join_no_label_manager(self):
191+
"""Test join when label manager not configured"""
192+
loader = MockLoader({}, label_manager=None)
193+
194+
batch = pa.RecordBatch.from_pydict({'address': ['0xA'], 'amount': [100]})
195+
196+
with pytest.raises(ValueError, match='Label manager not configured'):
197+
loader._join_with_labels(batch, 'tokens', 'address', 'address')

tests/unit/test_label_manager.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
"""Tests for LabelManager functionality"""
2+
3+
import tempfile
4+
from pathlib import Path
5+
6+
import pytest
7+
8+
from amp.config.label_manager import LabelManager
9+
10+
11+
class TestLabelManager:
12+
"""Test LabelManager class"""
13+
14+
def test_add_and_get_label(self):
15+
"""Test adding and retrieving a label dataset"""
16+
# Create a temporary CSV file with valid 40-char Ethereum addresses
17+
with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f:
18+
f.write('address,symbol,name\n')
19+
f.write('0x1234567890123456789012345678901234567890,ETH,Ethereum\n')
20+
f.write('0xabcdefabcdefabcdefabcdefabcdefabcdefabcd,BTC,Bitcoin\n')
21+
csv_path = f.name
22+
23+
try:
24+
manager = LabelManager()
25+
26+
# Add label
27+
manager.add_label('tokens', csv_path)
28+
29+
# Get label
30+
label_table = manager.get_label('tokens')
31+
32+
assert label_table is not None
33+
assert label_table.num_rows == 2
34+
assert len(label_table.schema) == 3
35+
assert 'address' in label_table.schema.names
36+
assert 'symbol' in label_table.schema.names
37+
assert 'name' in label_table.schema.names
38+
39+
finally:
40+
Path(csv_path).unlink()
41+
42+
def test_get_nonexistent_label(self):
43+
"""Test getting a label that doesn't exist"""
44+
manager = LabelManager()
45+
label_table = manager.get_label('nonexistent')
46+
assert label_table is None
47+
48+
def test_list_labels(self):
49+
"""Test listing all configured labels"""
50+
with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f:
51+
f.write('id,value\n')
52+
f.write('1,a\n')
53+
csv_path = f.name
54+
55+
try:
56+
manager = LabelManager()
57+
manager.add_label('test1', csv_path)
58+
manager.add_label('test2', csv_path)
59+
60+
labels = manager.list_labels()
61+
assert 'test1' in labels
62+
assert 'test2' in labels
63+
assert len(labels) == 2
64+
65+
finally:
66+
Path(csv_path).unlink()
67+
68+
def test_replace_label(self):
69+
"""Test replacing an existing label"""
70+
with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f:
71+
f.write('id,value\n')
72+
f.write('1,a\n')
73+
f.write('2,b\n')
74+
csv_path1 = f.name
75+
76+
with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f:
77+
f.write('id,value\n')
78+
f.write('1,x\n')
79+
csv_path2 = f.name
80+
81+
try:
82+
manager = LabelManager()
83+
manager.add_label('test', csv_path1)
84+
85+
# First version
86+
label1 = manager.get_label('test')
87+
assert label1.num_rows == 2
88+
89+
# Replace with new version
90+
manager.add_label('test', csv_path2)
91+
label2 = manager.get_label('test')
92+
assert label2.num_rows == 1
93+
94+
finally:
95+
Path(csv_path1).unlink()
96+
Path(csv_path2).unlink()
97+
98+
def test_remove_label(self):
99+
"""Test removing a label"""
100+
with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f:
101+
f.write('id,value\n')
102+
f.write('1,a\n')
103+
csv_path = f.name
104+
105+
try:
106+
manager = LabelManager()
107+
manager.add_label('test', csv_path)
108+
109+
# Verify it exists
110+
assert manager.get_label('test') is not None
111+
112+
# Remove it
113+
result = manager.remove_label('test')
114+
assert result is True
115+
116+
# Verify it's gone
117+
assert manager.get_label('test') is None
118+
119+
# Try to remove again
120+
result = manager.remove_label('test')
121+
assert result is False
122+
123+
finally:
124+
Path(csv_path).unlink()
125+
126+
def test_clear_labels(self):
127+
"""Test clearing all labels"""
128+
with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f:
129+
f.write('id,value\n')
130+
f.write('1,a\n')
131+
csv_path = f.name
132+
133+
try:
134+
manager = LabelManager()
135+
manager.add_label('test1', csv_path)
136+
manager.add_label('test2', csv_path)
137+
138+
assert len(manager.list_labels()) == 2
139+
140+
manager.clear()
141+
142+
assert len(manager.list_labels()) == 0
143+
144+
finally:
145+
Path(csv_path).unlink()
146+
147+
def test_invalid_csv_path(self):
148+
"""Test adding a label with invalid CSV path"""
149+
manager = LabelManager()
150+
151+
with pytest.raises(FileNotFoundError):
152+
manager.add_label('test', '/nonexistent/path.csv')

0 commit comments

Comments
 (0)