
Commit ec27e15

Add base infrastructure for loader test generalization
Foundational work to enable all loader tests to inherit common test patterns
1 parent 8ce25f1 commit ec27e15

File tree: 3 files changed, +652 −0 lines changed
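The diff below is one of the commit's three files: the PostgreSQL test module. Per its imports, the other two are presumably tests/integration/loaders/conftest.py (defining LoaderTestConfig) and tests/integration/loaders/test_base_loader.py (defining BaseLoaderTests), neither of which is shown here. As a reading aid, here is a minimal sketch of the shape those base classes would need for the subclass in this diff to work; the hook names and capability flags are taken from the diff, everything else is an assumption, not the commit's actual code.

# Hypothetical sketch only -- not the commit's actual code.
from typing import Any, Dict, List, Optional


class LoaderTestConfig:
    """Backend-specific configuration and database hooks for the shared tests."""

    loader_class = None  # concrete loader class under test
    config_fixture_name = ''  # name of the pytest fixture supplying its config

    # Capability flags; the base tests can skip features a backend lacks
    supports_overwrite = False
    supports_streaming = False
    supports_multi_network = False
    supports_null_values = False

    def get_row_count(self, loader, table_name: str) -> int:
        raise NotImplementedError

    def query_rows(
        self, loader, table_name: str, where: Optional[str] = None, order_by: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        raise NotImplementedError

    def cleanup_table(self, loader, table_name: str) -> None:
        raise NotImplementedError

    def get_column_names(self, loader, table_name: str) -> List[str]:
        raise NotImplementedError


class BaseLoaderTests:
    """Backend-agnostic tests written against the LoaderTestConfig hooks."""

    config: LoaderTestConfig  # each subclass assigns its backend's config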
Lines changed: 345 additions & 0 deletions
@@ -0,0 +1,345 @@
"""
PostgreSQL-specific loader integration tests.

This module provides PostgreSQL-specific test configuration and tests that
inherit from the generalized base test classes.
"""

import time
from typing import Any, Dict, List, Optional

import pytest

try:
    from src.amp.loaders.implementations.postgresql_loader import PostgreSQLLoader
    from tests.integration.loaders.conftest import LoaderTestConfig
    from tests.integration.loaders.test_base_loader import BaseLoaderTests
except ImportError:
    pytest.skip('amp modules not available', allow_module_level=True)


class PostgreSQLTestConfig(LoaderTestConfig):
    """PostgreSQL-specific test configuration"""

    loader_class = PostgreSQLLoader
    config_fixture_name = 'postgresql_test_config'

    supports_overwrite = True
    supports_streaming = True
    supports_multi_network = True
    supports_null_values = True

    def get_row_count(self, loader: PostgreSQLLoader, table_name: str) -> int:
        """Get row count from PostgreSQL table"""
        conn = loader.pool.getconn()
        try:
            with conn.cursor() as cur:
                # Identifiers cannot be bound as query parameters; table names
                # come from test fixtures, so f-string interpolation is
                # acceptable in this test-only helper.
                cur.execute(f'SELECT COUNT(*) FROM {table_name}')
                return cur.fetchone()[0]
        finally:
            loader.pool.putconn(conn)

    def query_rows(
        self, loader: PostgreSQLLoader, table_name: str, where: Optional[str] = None, order_by: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """Query rows from PostgreSQL table"""
        conn = loader.pool.getconn()
        try:
            with conn.cursor() as cur:
                # Get column names first
                cur.execute(
                    """
                    SELECT column_name
                    FROM information_schema.columns
                    WHERE table_name = %s
                    ORDER BY ordinal_position
                    """,
                    (table_name,),
                )
                columns = [row[0] for row in cur.fetchall()]

                # Build query
                query = f'SELECT * FROM {table_name}'
                if where:
                    query += f' WHERE {where}'
                if order_by:
                    query += f' ORDER BY {order_by}'

                cur.execute(query)
                rows = cur.fetchall()

                # Convert to list of dicts
                return [dict(zip(columns, row)) for row in rows]
        finally:
            loader.pool.putconn(conn)

    def cleanup_table(self, loader: PostgreSQLLoader, table_name: str) -> None:
        """Drop PostgreSQL table"""
        conn = loader.pool.getconn()
        try:
            with conn.cursor() as cur:
                cur.execute(f'DROP TABLE IF EXISTS {table_name} CASCADE')
            conn.commit()
        finally:
            loader.pool.putconn(conn)

    def get_column_names(self, loader: PostgreSQLLoader, table_name: str) -> List[str]:
        """Get column names from PostgreSQL table"""
        conn = loader.pool.getconn()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT column_name
                    FROM information_schema.columns
                    WHERE table_name = %s
                    ORDER BY ordinal_position
                    """,
                    (table_name,),
                )
                return [row[0] for row in cur.fetchall()]
        finally:
            loader.pool.putconn(conn)


@pytest.mark.postgresql
class TestPostgreSQLCore(BaseLoaderTests):
    """PostgreSQL core loader tests (inherited from base)"""

    config = PostgreSQLTestConfig()


@pytest.fixture
def cleanup_tables(postgresql_test_config):
    """Cleanup test tables after tests"""
    tables_to_clean = []

    yield tables_to_clean

    # Cleanup
    loader = PostgreSQLLoader(postgresql_test_config)
    try:
        loader.connect()
        conn = loader.pool.getconn()
        try:
            with conn.cursor() as cur:
                for table in tables_to_clean:
                    try:
                        cur.execute(f'DROP TABLE IF EXISTS {table} CASCADE')
                        conn.commit()
                    except Exception:
                        pass
        finally:
            loader.pool.putconn(conn)
        loader.disconnect()
    except Exception:
        pass


@pytest.mark.postgresql
class TestPostgreSQLSpecific:
    """PostgreSQL-specific tests that cannot be generalized"""

    def test_connection_pooling(self, postgresql_test_config, small_test_data, test_table_name, cleanup_tables):
        """Test PostgreSQL connection pooling behavior"""
        from src.amp.loaders.base import LoadMode

        cleanup_tables.append(test_table_name)

        loader = PostgreSQLLoader(postgresql_test_config)

        with loader:
            # Perform multiple operations to test pool reuse
            for i in range(5):
                subset = small_test_data.slice(i, 1)
                mode = LoadMode.OVERWRITE if i == 0 else LoadMode.APPEND

                result = loader.load_table(subset, test_table_name, mode=mode)
                assert result.success

            # Verify pool is managing connections properly
            # Note: _used is a dict in ThreadedConnectionPool, not an int
            assert len(loader.pool._used) <= loader.pool.maxconn

    def test_binary_data_handling(self, postgresql_test_config, test_table_name, cleanup_tables):
        """Test binary data handling with INSERT fallback"""
        import pyarrow as pa

        cleanup_tables.append(test_table_name)

        # Create data with binary columns
        data = {'id': [1, 2, 3], 'binary_data': [b'hello', b'world', b'test'], 'text_data': ['a', 'b', 'c']}
        table = pa.Table.from_pydict(data)

        loader = PostgreSQLLoader(postgresql_test_config)

        with loader:
            result = loader.load_table(table, test_table_name)
            assert result.success
            assert result.rows_loaded == 3

            # Verify binary data was stored correctly; psycopg2 returns bytea
            # columns as memoryview objects, hence tobytes()
            conn = loader.pool.getconn()
            try:
                with conn.cursor() as cur:
                    cur.execute(f'SELECT id, binary_data FROM {test_table_name} ORDER BY id')
                    rows = cur.fetchall()
                    assert rows[0][1].tobytes() == b'hello'
                    assert rows[1][1].tobytes() == b'world'
                    assert rows[2][1].tobytes() == b'test'
            finally:
                loader.pool.putconn(conn)

    def test_schema_retrieval(self, postgresql_test_config, small_test_data, test_table_name, cleanup_tables):
        """Test schema retrieval functionality"""
        cleanup_tables.append(test_table_name)

        loader = PostgreSQLLoader(postgresql_test_config)

        with loader:
            # Create table
            result = loader.load_table(small_test_data, test_table_name)
            assert result.success

            # Get schema
            schema = loader.get_table_schema(test_table_name)
            assert schema is not None

            # Filter out metadata columns added by the PostgreSQL loader
            non_meta_fields = [
                field for field in schema if not (field.name.startswith('_meta_') or field.name.startswith('_amp_'))
            ]

            assert len(non_meta_fields) == len(small_test_data.schema)

            # Verify column names match (excluding metadata columns)
            original_names = set(small_test_data.schema.names)
            retrieved_names = set(field.name for field in non_meta_fields)
            assert original_names == retrieved_names

    def test_performance_metrics(self, postgresql_test_config, medium_test_table, test_table_name, cleanup_tables):
        """Test performance metrics in results"""
        cleanup_tables.append(test_table_name)

        loader = PostgreSQLLoader(postgresql_test_config)

        with loader:
            start_time = time.time()
            result = loader.load_table(medium_test_table, test_table_name)
            end_time = time.time()

            assert result.success
            assert result.duration > 0
            assert result.duration <= (end_time - start_time)
            assert result.rows_loaded == 10000

            # Check metadata contains performance info
            assert 'table_size_bytes' in result.metadata
            assert result.metadata['table_size_bytes'] > 0

    def test_null_value_handling_detailed(
        self, postgresql_test_config, null_test_data, test_table_name, cleanup_tables
    ):
        """Test comprehensive null value handling across all PostgreSQL data types"""
        cleanup_tables.append(test_table_name)

        loader = PostgreSQLLoader(postgresql_test_config)

        with loader:
            result = loader.load_table(null_test_data, test_table_name)
            assert result.success
            assert result.rows_loaded == 10

            conn = loader.pool.getconn()
            try:
                with conn.cursor() as cur:
                    # Check text field nulls (rows 3, 6, 9, i.e. indices 2, 5, 8)
                    cur.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE text_field IS NULL')
                    text_nulls = cur.fetchone()[0]
                    assert text_nulls == 3

                    # Check int field nulls (rows 2, 5, 8, i.e. indices 1, 4, 7)
                    cur.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE int_field IS NULL')
                    int_nulls = cur.fetchone()[0]
                    assert int_nulls == 3

                    # Check float field nulls (rows 3, 6, 9, i.e. indices 2, 5, 8)
                    cur.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE float_field IS NULL')
                    float_nulls = cur.fetchone()[0]
                    assert float_nulls == 3

                    # Check bool field nulls (rows 3, 6, 9, i.e. indices 2, 5, 8)
                    cur.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE bool_field IS NULL')
                    bool_nulls = cur.fetchone()[0]
                    assert bool_nulls == 3

                    # Check timestamp field nulls
                    cur.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE timestamp_field IS NULL')
                    timestamp_nulls = cur.fetchone()[0]
                    assert timestamp_nulls == 4

                    # Check json field nulls
                    cur.execute(f'SELECT COUNT(*) FROM {test_table_name} WHERE json_field IS NULL')
                    json_nulls = cur.fetchone()[0]
                    assert json_nulls == 3

                    # Verify non-null values are intact
                    cur.execute(f'SELECT text_field FROM {test_table_name} WHERE id = 1')
                    text_val = cur.fetchone()[0]
                    assert text_val in ['a', '"a"']  # Handle potential CSV quoting

                    cur.execute(f'SELECT int_field FROM {test_table_name} WHERE id = 1')
                    int_val = cur.fetchone()[0]
                    assert int_val == 1

                    cur.execute(f'SELECT float_field FROM {test_table_name} WHERE id = 1')
                    float_val = cur.fetchone()[0]
                    assert abs(float_val - 1.1) < 0.01

                    cur.execute(f'SELECT bool_field FROM {test_table_name} WHERE id = 1')
                    bool_val = cur.fetchone()[0]
                    assert bool_val is True
            finally:
                loader.pool.putconn(conn)


@pytest.mark.postgresql
@pytest.mark.slow
class TestPostgreSQLPerformance:
    """PostgreSQL performance tests"""

    def test_large_data_loading(self, postgresql_test_config, test_table_name, cleanup_tables):
        """Test loading large datasets"""
        from datetime import datetime

        import pyarrow as pa

        cleanup_tables.append(test_table_name)

        # Create large dataset
        large_data = {
            'id': list(range(50000)),
            'value': [i * 0.123 for i in range(50000)],
            'category': [f'category_{i % 100}' for i in range(50000)],
            'description': [f'This is a longer text description for row {i}' for i in range(50000)],
            'created_at': [datetime.now() for _ in range(50000)],
        }
        large_table = pa.Table.from_pydict(large_data)

        loader = PostgreSQLLoader(postgresql_test_config)

        with loader:
            result = loader.load_table(large_table, test_table_name)

            assert result.success
            assert result.rows_loaded == 50000
            assert result.duration < 60  # Should complete within 60 seconds

            # Verify data integrity
            conn = loader.pool.getconn()
            try:
                with conn.cursor() as cur:
                    cur.execute(f'SELECT COUNT(*) FROM {test_table_name}')
                    count = cur.fetchone()[0]
                    assert count == 50000
            finally:
                loader.pool.putconn(conn)
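The tests above lean on fixtures that are not part of this file (postgresql_test_config, small_test_data, medium_test_table, null_test_data, test_table_name); they presumably live in the shared conftest added by this commit. As one illustration of the pattern, a table-name fixture might generate a unique throwaway name per test so concurrent runs cannot collide; this is a guess at the shape, not the commit's code:

import uuid

import pytest


@pytest.fixture
def test_table_name():
    # Hypothetical: unique table name per test; the cleanup_tables fixture
    # drops whatever names each test appends to its list
    return f'loader_test_{uuid.uuid4().hex[:8]}'

Because the classes are tagged with @pytest.mark.postgresql (and the performance class also with @pytest.mark.slow), the suite can be scoped from the command line, e.g. pytest -m 'postgresql and not slow', assuming both markers are registered in the project's pytest configuration.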
