Skip to content

Commit 7392c9a

Browse files
committed
follow-up + tests
1 parent c9c38c5 commit 7392c9a

File tree

2 files changed

+151
-1
lines changed

2 files changed

+151
-1
lines changed

pum/sql_content.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@
1212

1313

1414
class CursorResult:
15-
"""A simple wrapper to hold cursor results after the cursor is closed."""
15+
"""A simple wrapper to hold cursor results after the cursor is closed.
16+
17+
This class provides a cursor-compatible interface for accessing query results
18+
after the actual database cursor has been closed.
19+
"""
1620

1721
def __init__(
1822
self, results: Optional[list] = None, description: Optional[Any] = None, rowcount: int = 0
@@ -22,6 +26,36 @@ def __init__(
2226
self._pum_rowcount = rowcount
2327
self._pum_index = 0
2428

29+
@property
30+
def description(self):
31+
"""Return the column description (compatible with cursor.description)."""
32+
return self._pum_description
33+
34+
@property
35+
def rowcount(self):
36+
"""Return the number of rows (compatible with cursor.rowcount)."""
37+
return self._pum_rowcount
38+
39+
def fetchall(self):
40+
"""Return all results (compatible with cursor.fetchall())."""
41+
return self._pum_results if self._pum_results is not None else []
42+
43+
def fetchone(self):
44+
"""Return the next result (compatible with cursor.fetchone())."""
45+
if self._pum_results is None or self._pum_index >= len(self._pum_results):
46+
return None
47+
result = self._pum_results[self._pum_index]
48+
self._pum_index += 1
49+
return result
50+
51+
def fetchmany(self, size: int = 1):
52+
"""Return the next `size` results (compatible with cursor.fetchmany())."""
53+
if self._pum_results is None:
54+
return []
55+
results = self._pum_results[self._pum_index : self._pum_index + size]
56+
self._pum_index += len(results)
57+
return results
58+
2559

2660
def sql_chunks_from_file(file: str | Path) -> list[psycopg.sql.SQL]:
2761
"""Read SQL from a file, remove comments, and split into chunks.

test/test_sql_content.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
import unittest
2+
3+
import psycopg
4+
5+
from pum.sql_content import SqlContent
6+
7+
8+
class TestSqlContent(unittest.TestCase):
9+
"""Test the SqlContent class and CursorResult."""
10+
11+
def setUp(self) -> None:
12+
"""Set up the test environment."""
13+
self.pg_service = "pum_test"
14+
15+
def test_cursor_result_fetchall(self) -> None:
16+
"""Test CursorResult.fetchall() returns all results."""
17+
with psycopg.connect(f"service={self.pg_service}") as conn:
18+
result = SqlContent("SELECT 1 AS a, 2 AS b UNION SELECT 3, 4").execute(conn)
19+
20+
# Test fetchall returns all rows
21+
rows = result.fetchall()
22+
self.assertEqual(len(rows), 2)
23+
self.assertEqual(rows[0], (1, 2))
24+
self.assertEqual(rows[1], (3, 4))
25+
26+
def test_cursor_result_fetchone(self) -> None:
27+
"""Test CursorResult.fetchone() returns rows one at a time."""
28+
with psycopg.connect(f"service={self.pg_service}") as conn:
29+
result = SqlContent("SELECT 1 AS a, 2 AS b UNION SELECT 3, 4 ORDER BY a").execute(conn)
30+
31+
# Test fetchone returns rows sequentially
32+
row1 = result.fetchone()
33+
self.assertEqual(row1, (1, 2))
34+
35+
row2 = result.fetchone()
36+
self.assertEqual(row2, (3, 4))
37+
38+
# Test fetchone returns None when no more rows
39+
row3 = result.fetchone()
40+
self.assertIsNone(row3)
41+
42+
def test_cursor_result_fetchmany(self) -> None:
43+
"""Test CursorResult.fetchmany() returns specified number of rows."""
44+
with psycopg.connect(f"service={self.pg_service}") as conn:
45+
result = SqlContent("SELECT generate_series(1, 5) AS n").execute(conn)
46+
47+
# Test fetchmany with size=2
48+
rows1 = result.fetchmany(2)
49+
self.assertEqual(len(rows1), 2)
50+
self.assertEqual(rows1[0], (1,))
51+
self.assertEqual(rows1[1], (2,))
52+
53+
# Test fetchmany with size=2 again
54+
rows2 = result.fetchmany(2)
55+
self.assertEqual(len(rows2), 2)
56+
self.assertEqual(rows2[0], (3,))
57+
self.assertEqual(rows2[1], (4,))
58+
59+
# Test fetchmany returns remaining rows when fewer than size
60+
rows3 = result.fetchmany(2)
61+
self.assertEqual(len(rows3), 1)
62+
self.assertEqual(rows3[0], (5,))
63+
64+
# Test fetchmany returns empty list when no more rows
65+
rows4 = result.fetchmany(2)
66+
self.assertEqual(rows4, [])
67+
68+
def test_cursor_result_description(self) -> None:
69+
"""Test CursorResult.description returns column information."""
70+
with psycopg.connect(f"service={self.pg_service}") as conn:
71+
result = SqlContent("SELECT 1 AS a, 'hello' AS b").execute(conn)
72+
73+
# Test description is accessible
74+
self.assertIsNotNone(result.description)
75+
self.assertEqual(len(result.description), 2)
76+
self.assertEqual(result.description[0][0], "a")
77+
self.assertEqual(result.description[1][0], "b")
78+
79+
def test_cursor_result_rowcount(self) -> None:
80+
"""Test CursorResult.rowcount returns number of rows."""
81+
with psycopg.connect(f"service={self.pg_service}") as conn:
82+
result = SqlContent("SELECT 1 UNION SELECT 2 UNION SELECT 3").execute(conn)
83+
84+
# Test rowcount is accessible
85+
self.assertEqual(result.rowcount, 3)
86+
87+
def test_cursor_result_no_results(self) -> None:
88+
"""Test CursorResult handles DDL statements with no results."""
89+
with psycopg.connect(f"service={self.pg_service}") as conn:
90+
# Create temporary table
91+
result = SqlContent("CREATE TEMP TABLE test_temp (id INT)").execute(conn)
92+
93+
# Test fetchall returns empty list for DDL
94+
rows = result.fetchall()
95+
self.assertEqual(rows, [])
96+
97+
# Test fetchone returns None for DDL
98+
row = result.fetchone()
99+
self.assertIsNone(row)
100+
101+
def test_cursor_result_internal_attributes(self) -> None:
102+
"""Test CursorResult internal attributes for backward compatibility."""
103+
with psycopg.connect(f"service={self.pg_service}") as conn:
104+
result = SqlContent("SELECT 1 AS a, 2 AS b").execute(conn)
105+
106+
# Test internal attributes used by existing code
107+
self.assertIsNotNone(result._pum_results)
108+
self.assertEqual(len(result._pum_results), 1)
109+
self.assertEqual(result._pum_results[0], (1, 2))
110+
111+
self.assertIsNotNone(result._pum_description)
112+
self.assertEqual(result._pum_description[0][0], "a")
113+
114+
115+
if __name__ == "__main__":
116+
unittest.main()

0 commit comments

Comments
 (0)