Skip to content

Commit a7a4e2d

Browse files
committed
Implement strict checking in one; #13
1 parent 7c69bcd commit a7a4e2d

File tree

2 files changed

+184
-12
lines changed

2 files changed

+184
-12
lines changed

postgres.py

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,11 @@
2323
>>> db.run("INSERT INTO foo VALUES ('baz')")
2424
>>> db.run("INSERT INTO foo VALUES ('buz')")
2525
26-
Use :py:meth:`~postgres.Postgres.one` to fetch one result:
26+
Use :py:meth:`~postgres.Postgres.one` to fetch exactly one result:
2727
28-
>>> db.one("SELECT * FROM foo ORDER BY bar")
28+
>>> db.one("SELECT * FROM foo WHERE bar='baz'")
2929
{'bar': 'baz'}
30-
>>> db.one("SELECT * FROM foo WHERE bar='blam'")
31-
>>> # None
30+
3231
3332
Use :py:meth:`~postgres.Postgres.rows` to fetch all results:
3433
@@ -177,6 +176,19 @@ def url_to_dsn(url):
177176
return dsn
178177

179178

179+
# Exceptions
180+
# ==========
181+
182+
class NotOne(Exception):
183+
def __init__(self, rowcount):
184+
self.rowcount = rowcount
185+
def __str__(self):
186+
return "Got {} rows instead of 1.".format(self.rowcount)
187+
188+
class TooFew(NotOne): pass
189+
class TooMany(NotOne): pass
190+
191+
180192
# The Main Event
181193
# ==============
182194

@@ -186,6 +198,8 @@ class Postgres(object):
186198
:param unicode url: A ``postgres://`` URL or a `PostgreSQL connection string <http://www.postgresql.org/docs/current/static/libpq-connect.html>`_
187199
:param int minconn: The minimum size of the connection pool
188200
:param int maxconn: The minimum size of the connection pool
201+
:param strict_one: The default :py:attr:`strict` parameter for :py:meth:`~postgres.Postgres.one`
202+
:type strict_one: :py:class:`bool`
189203
190204
This is the main object that :py:mod:`postgres` provides, and you should
191205
have one instance per process for each PostgreSQL database your process
@@ -220,15 +234,20 @@ class Postgres(object):
220234
221235
"""
222236

223-
def __init__(self, url, minconn=1, maxconn=10):
237+
def __init__(self, url, minconn=1, maxconn=10, strict_one=None):
224238
if url.startswith("postgres://"):
225239
dsn = url_to_dsn(url)
240+
226241
self.pool = ConnectionPool( minconn=minconn
227242
, maxconn=maxconn
228243
, dsn=dsn
229244
, connection_factory=Connection
230245
)
231246

247+
if strict_one not in (True, False, None):
248+
raise ValueError("strict_one must be True, False, or None.")
249+
self.strict_one = strict_one
250+
232251
def run(self, sql, parameters=None):
233252
"""Execute a query and discard any results.
234253
@@ -245,21 +264,48 @@ def run(self, sql, parameters=None):
245264
with self.get_cursor() as cursor:
246265
cursor.execute(sql, parameters)
247266

248-
def one(self, sql, parameters=None):
267+
def one(self, sql, parameters=None, strict=None):
249268
"""Execute a query and return a single result.
250269
251270
:param unicode sql: the SQL statement to execute
252271
:param parameters: the bind parameters for the SQL statement
253272
:type parameters: dict or tuple
273+
:param strict: whether to raise when there isn't exactly one result
274+
:type strict: :py:class:`bool`
254275
:returns: :py:class:`dict` or :py:const:`None`
276+
:raises: :py:exc:`~postgres.TooFew` or :py:exc:`~postgres.TooMany`
277+
278+
By default, :py:attr:`strict` ends up evaluating to :py:class:`True`,
279+
in which case we raise :py:exc:`~postgres.TooFew` or
280+
:py:exc:`~postgres.TooMany` if the number of rows returned isn't
281+
exactly one. You can override this behavior per-call with the
282+
:py:attr:`strict` argument here, or globally by passing
283+
:py:attr:`strict_one` to the :py:class:`~postgres.Postgres`
284+
constructor. If you use both, the :py:attr:`strict` argument here wins.
255285
256286
>>> row = db.one("SELECT * FROM foo WHERE bar='baz'"):
257287
>>> print(row["bar"])
258288
baz
259289
260290
"""
291+
if strict not in (True, False, None):
292+
raise ValueError("strict must be True, False, or None.")
293+
294+
if strict is None:
295+
if self.strict_one is None:
296+
strict = True # library default
297+
else:
298+
strict = self.strict_one # user default
299+
261300
with self.get_cursor() as cursor:
262301
cursor.execute(sql, parameters)
302+
303+
if strict:
304+
if cursor.rowcount < 1:
305+
raise TooFew(cursor.rowcount)
306+
elif cursor.rowcount > 1:
307+
raise TooMany(cursor.rowcount)
308+
263309
return cursor.fetchone()
264310

265311
def rows(self, sql, parameters=None):

tests.py

Lines changed: 132 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@
33
import os
44
from unittest import TestCase
55

6-
from postgres import Postgres
6+
from postgres import Postgres, TooFew, TooMany
77

88

99
DATABASE_URL = os.environ['DATABASE_URL']
1010

1111

12+
# harnesses
13+
# =========
14+
1215
class WithSchema(TestCase):
1316

1417
def setUp(self):
@@ -18,6 +21,7 @@ def setUp(self):
1821

1922
def tearDown(self):
2023
self.db.run("DROP SCHEMA IF EXISTS public CASCADE")
24+
del self.db
2125

2226

2327
class WithData(WithSchema):
@@ -29,6 +33,9 @@ def setUp(self):
2933
self.db.run("INSERT INTO foo VALUES ('buz')")
3034

3135

36+
# db.run
37+
# ======
38+
3239
class TestRun(WithSchema):
3340

3441
def test_run_runs(self):
@@ -44,16 +51,126 @@ def test_run_inserts(self):
4451
assert actual == 1
4552

4653

47-
class TestOneAndRows(WithData):
54+
# db.one
55+
# ======
56+
# With all the combinations of strict_one and strict, we end up with a number
57+
# of tests here. Since the behavior of the one method with a strict parameter
58+
# of True or False is expected to be the same regardless of what strict_one is
59+
# set to, we can write those once and then use the TestOne TestCase as the base
60+
# class for other TestCases that vary the strict_one attribute. The TestOne
61+
# tests will be re-run in each new context.
4862

49-
def test_one_fetches_the_first_one(self):
50-
actual = self.db.one("SELECT * FROM foo ORDER BY bar")
63+
class TestNotOneException(WithData):
64+
65+
def test_TooFew_message_is_helpful(self):
66+
try:
67+
self.db.one("SELECT * FROM foo WHERE bar='blah'", strict=True)
68+
except TooFew, exc:
69+
actual = str(exc)
70+
assert actual == "Got 0 rows instead of 1."
71+
72+
def test_TooMany_message_is_helpful(self):
73+
try:
74+
self.db.one("SELECT * FROM foo", strict=True)
75+
except TooMany, exc:
76+
actual = str(exc)
77+
assert actual == "Got 2 rows instead of 1."
78+
79+
80+
class TestOne(WithData):
81+
82+
def test_with_strict_True_one_raises_TooFew(self):
83+
self.assertRaises( TooFew
84+
, self.db.one
85+
, "SELECT * FROM foo WHERE bar='blah'"
86+
, strict=True
87+
)
88+
89+
def test_with_strict_True_one_fetches_the_one(self):
90+
actual = self.db.one("SELECT * FROM foo WHERE bar='baz'", strict=True)
91+
assert actual == {"bar": "baz"}
92+
93+
def test_with_strict_True_one_raises_TooMany(self):
94+
self.assertRaises( TooMany
95+
, self.db.one
96+
, "SELECT * FROM foo"
97+
, strict=True
98+
)
99+
100+
101+
def test_with_strict_False_one_returns_None_if_theres_none(self):
102+
actual = self.db.one("SELECT * FROM foo WHERE bar='nun'", strict=False)
103+
assert actual is None
104+
105+
def test_with_strict_False_one_fetches_the_first_one(self):
106+
actual = self.db.one("SELECT * FROM foo ORDER BY bar", strict=False)
107+
assert actual == {"bar": "baz"}
108+
109+
110+
class TestOne_StrictOneNone(TestOne):
111+
112+
def setUp(self):
113+
WithData.setUp(self)
114+
self.db.strict_one = None
115+
116+
def test_one_raises_TooFew(self):
117+
self.assertRaises( TooFew
118+
, self.db.one
119+
, "SELECT * FROM foo WHERE bar='nun'"
120+
)
121+
122+
def test_one_returns_one(self):
123+
actual = self.db.one("SELECT * FROM foo WHERE bar='baz'")
51124
assert actual == {"bar": "baz"}
52125

53-
def test_one_returns_None_if_theres_none(self):
54-
actual = self.db.one("SELECT * FROM foo WHERE bar='blam'")
126+
def test_one_raises_TooMany(self):
127+
self.assertRaises(TooMany, self.db.one, "SELECT * FROM foo")
128+
129+
130+
class TestOne_StrictOneFalse(TestOne):
131+
132+
def setUp(self):
133+
WithData.setUp(self)
134+
self.db.strict_one = False
135+
136+
def test_one_returns_None(self):
137+
actual = self.db.one("SELECT * FROM foo WHERE bar='nun'")
55138
assert actual is None
56139

140+
def test_one_returns_one(self):
141+
actual = self.db.one("SELECT * FROM foo WHERE bar='baz'")
142+
assert actual == {"bar": "baz"}
143+
144+
def test_one_returns_first_one(self):
145+
actual = self.db.one("SELECT * FROM foo ORDER BY bar")
146+
assert actual == {"bar": "baz"}
147+
148+
149+
class TestOne_StrictOneTrue(TestOne):
150+
151+
def setUp(self):
152+
WithData.setUp(self)
153+
self.db.strict_one = True
154+
155+
def test_one_raises_TooFew(self):
156+
self.assertRaises( TooFew
157+
, self.db.one
158+
, "SELECT * FROM foo WHERE bar='nun'"
159+
)
160+
161+
def test_one_returns_one(self):
162+
actual = self.db.one("SELECT * FROM foo WHERE bar='baz'")
163+
assert actual == {"bar": "baz"}
164+
165+
def test_one_raises_TooMany(self):
166+
self.assertRaises(TooMany, self.db.one, "SELECT * FROM foo")
167+
168+
169+
# db.rows
170+
# =======
171+
172+
class TestRows(WithData):
173+
57174
def test_rows_fetches_all_rows(self):
58175
actual = self.db.rows("SELECT * FROM foo ORDER BY bar")
59176
assert actual == [{"bar": "baz"}, {"bar": "buz"}]
@@ -68,6 +185,9 @@ def test_bind_parameters_as_tuple_work(self):
68185
assert actual == [{"bar": "baz"}]
69186

70187

188+
# db.get_cursor
189+
# =============
190+
71191
class TestCursor(WithData):
72192

73193
def test_get_cursor_gets_a_cursor(self):
@@ -89,6 +209,9 @@ def test_we_can_use_cursor_closed(self):
89209
assert not actual
90210

91211

212+
# db.get_transaction
213+
# ==================
214+
92215
class TestTransaction(WithData):
93216

94217
def test_get_transaction_gets_a_transaction(self):
@@ -125,6 +248,9 @@ class Heck(Exception): pass
125248
assert actual == [{"bar": "baz"}, {"bar": "buz"}]
126249

127250

251+
# db.get_connection
252+
# =================
253+
128254
class TestConnection(WithData):
129255

130256
def test_get_connection_gets_a_connection(self):

0 commit comments

Comments
 (0)