Skip to content

Commit bc8c0e9

Browse files
committed
Improve test coverage for codecs, fix a few bugs found in the process
1 parent b68bfdc commit bc8c0e9

File tree

10 files changed

+178
-32
lines changed

10 files changed

+178
-32
lines changed

.coveragerc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
[run]
2+
branch = True
3+
plugins = Cython.Coverage
4+
source =
5+
asyncpg/
6+
tests/
7+
omit =
8+
*.pxd
9+
10+
[paths]
11+
source =
12+
asyncpg

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,5 @@ __pycache__/
2727
/dist
2828
/.cache
2929
docs/_build
30+
*,cover
31+
.coverage

asyncpg/prepared_stmt.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,7 @@ async def explain(self, *args, analyze=False):
9999
return json.loads(data)
100100

101101
async def fetch(self, *args, timeout=None):
102-
r"""Execute the statement and return the results as a list \
103-
of :class:`Record` objects.
102+
r"""Execute the statement and return a list of :class:`Record` objects.
104103
105104
:param str query: Query text
106105
:param args: Query arguments

asyncpg/protocol/codecs/__init__.py

Whitespace-only changes.

asyncpg/protocol/codecs/base.pyx

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ cdef class Codec:
2929
self.kind = kind
3030
self.type = type
3131
self.format = format
32-
self.type = type
3332
self.c_encoder = c_encoder
3433
self.c_decoder = c_decoder
3534
self.py_encoder = py_encoder
@@ -92,15 +91,17 @@ cdef class Codec:
9291
cdef:
9392
WriteBuffer elem_data
9493
int32_t i
94+
list elem_codecs = self.element_codecs
9595

9696
elem_data = WriteBuffer.new()
9797
i = 0
9898
for item in obj:
99-
elem_data.write_int32(self.element_type_ids[i])
99+
elem_data.write_int32(self.element_type_oids[i])
100100
if item is None:
101101
elem_data.write_int32(-1)
102102
else:
103-
self.element_codecs[i].encode(settings, elem_data, item)
103+
(<Codec>elem_codecs[i]).encode(settings, elem_data, item)
104+
i += 1
104105

105106
record_encode_frame(settings, buf, elem_data, len(obj))
106107

asyncpg/protocol/codecs/bytea.pyx

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,6 @@
55
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
66

77

8-
import decimal
9-
10-
_Dec = decimal.Decimal
11-
12-
138
cdef bytea_encode(ConnectionSettings settings, WriteBuffer wbuf, obj):
149
cdef:
1510
Py_buffer pybuf

asyncpg/protocol/codecs/hstore.pyx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ cdef hstore_encode(ConnectionSettings settings, WriteBuffer buf, obj):
2222
items = obj
2323

2424
for k, v in items:
25+
if k is None:
26+
raise ValueError('null value not allowed in hstore key')
2527
as_pg_string_and_size(settings, k, &str, &size)
2628
item_buf.write_int32(<int32_t>size)
2729
item_buf.write_cstr(str, size)

asyncpg/protocol/codecs/misc.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
66

77

8-
cdef void_encode(ConnectionSettings settings, FastReadBuffer buf):
8+
cdef void_encode(ConnectionSettings settings, WriteBuffer buf, obj):
99
# Void is zero bytes
1010
buf.write_int32(0)
1111

asyncpg/protocol/codecs/text.pyx

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,16 @@
66

77

88
cdef inline as_pg_string_and_size(
9-
ConnectionSettings settings, obj, char **str, ssize_t *size):
9+
ConnectionSettings settings, obj, char **cstr, ssize_t *size):
10+
11+
if not cpython.PyUnicode_Check(obj):
12+
obj = str(obj)
1013

1114
if settings.is_encoding_utf8():
12-
str[0] = PyUnicode_AsUTF8AndSize(obj, size)
15+
cstr[0] = PyUnicode_AsUTF8AndSize(obj, size)
1316
else:
1417
encoded = settings.get_text_codec().encode(obj)
15-
cpython.PyBytes_AsStringAndSize(encoded, str, size)
18+
cpython.PyBytes_AsStringAndSize(encoded, cstr, size)
1619

1720
if size[0] > 0x7fffffff:
1821
raise ValueError('string too long')

tests/test_codecs.py

Lines changed: 150 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import ipaddress
1111
import math
1212
import random
13+
import struct
1314
import uuid
1415

1516
import asyncpg
@@ -31,6 +32,9 @@ def _timezone(offset):
3132

3233

3334
type_samples = [
35+
('bool', 'bool', (
36+
True, False,
37+
)),
3438
('smallint', 'int2', (
3539
-2 ** 15 + 1, 2 ** 15 - 1,
3640
-1, 0, 1,
@@ -132,7 +136,8 @@ def _timezone(offset):
132136
bytes(range(255, -1, -1)),
133137
b'\x00\x00',
134138
b'foo',
135-
b'f' * 1024 * 1024
139+
b'f' * 1024 * 1024,
140+
dict(input=bytearray(b'\x02\x01'), output=b'\x02\x01'),
136141
)),
137142
('text', 'text', (
138143
'',
@@ -156,6 +161,7 @@ def _timezone(offset):
156161
datetime.date(2000, 1, 1),
157162
datetime.date(500, 1, 1),
158163
datetime.date(1, 1, 1),
164+
infinity_date,
159165
]),
160166
('time', 'time', [
161167
datetime.time(12, 15, 20),
@@ -191,7 +197,9 @@ def _timezone(offset):
191197
]),
192198
('uuid', 'uuid', [
193199
uuid.UUID('38a4ff5a-3a56-11e6-a6c2-c8f73323c6d4'),
194-
uuid.UUID('00000000-0000-0000-0000-000000000000')
200+
uuid.UUID('00000000-0000-0000-0000-000000000000'),
201+
{'input': '00000000-0000-0000-0000-000000000000',
202+
'output': uuid.UUID('00000000-0000-0000-0000-000000000000')}
195203
]),
196204
('uuid[]', 'uuid[]', [
197205
(uuid.UUID('38a4ff5a-3a56-11e6-a6c2-c8f73323c6d4'),
@@ -294,11 +302,21 @@ def _timezone(offset):
294302
asyncpg.BitString(),
295303
asyncpg.BitString.frombytes(b'\x00', bitlength=3),
296304
asyncpg.BitString('0000 0000 1'),
305+
dict(input=b'\x01', output=asyncpg.BitString('0000 0001')),
306+
dict(input=bytearray(b'\x02'), output=asyncpg.BitString('0000 0010')),
297307
]),
298308
('path', 'path', [
299309
asyncpg.Path(asyncpg.Point(0.0, 0.0), asyncpg.Point(1.0, 1.0)),
300310
asyncpg.Path(asyncpg.Point(0.0, 0.0), asyncpg.Point(1.0, 1.0),
301311
is_closed=True),
312+
dict(input=((0.0, 0.0), (1.0, 1.0)),
313+
output=asyncpg.Path(asyncpg.Point(0.0, 0.0),
314+
asyncpg.Point(1.0, 1.0),
315+
is_closed=True)),
316+
dict(input=[(0.0, 0.0), (1.0, 1.0)],
317+
output=asyncpg.Path(asyncpg.Point(0.0, 0.0),
318+
asyncpg.Point(1.0, 1.0),
319+
is_closed=False)),
302320
]),
303321
('point', 'point', [
304322
asyncpg.Point(0.0, 0.0),
@@ -334,22 +352,28 @@ async def test_standard_codecs(self):
334352

335353
for sample in sample_data:
336354
with self.subTest(sample=sample, typname=typname):
337-
rsample = await st.fetchval(sample)
355+
if isinstance(sample, dict):
356+
inputval = sample['input']
357+
outputval = sample['output']
358+
else:
359+
inputval = outputval = sample
360+
361+
result = await st.fetchval(inputval)
338362
err_msg = (
339-
"failed to return {} object data as-is; "
340-
"gave {!r}, received {!r}".format(
341-
typname, sample, rsample))
363+
"unexpected result for {} when passing {!r}: "
364+
"received {!r}, expected {!r}".format(
365+
typname, inputval, result, outputval))
342366

343367
if typname.startswith('float'):
344-
if math.isnan(sample):
345-
if not math.isnan(rsample):
368+
if math.isnan(outputval):
369+
if not math.isnan(result):
346370
self.fail(err_msg)
347371
else:
348372
self.assertTrue(
349-
math.isclose(rsample, sample, rel_tol=1e-6),
373+
math.isclose(result, outputval, rel_tol=1e-6),
350374
err_msg)
351375
else:
352-
self.assertEqual(rsample, sample, err_msg)
376+
self.assertEqual(result, outputval, err_msg)
353377

354378
with self.subTest(sample=None, typname=typname):
355379
# Test that None is handled for all types.
@@ -369,10 +393,9 @@ async def test_all_builtin_types_handled(self):
369393
'core type {} ({}) is unhandled'.format(typename, oid))
370394

371395
async def test_void(self):
372-
stmt = await self.con.prepare('select pg_sleep(0)')
373-
self.assertIsNone(await stmt.fetchval())
374-
375-
await self.con.fetchval('select now($1::void)', None)
396+
res = await self.con.fetchval('select pg_sleep(0)')
397+
self.assertIsNone(res)
398+
await self.con.fetchval('select now($1::void)', '')
376399

377400
def test_bitstring(self):
378401
bitlen = random.randint(0, 1000)
@@ -424,6 +447,10 @@ async def test_invalid_input(self):
424447
32768,
425448
-32768
426449
]),
450+
('float4', ValueError, 'float value too large', [
451+
4.1 * 10 ** 40,
452+
-4.1 * 10 ** 40,
453+
]),
427454
('int4', TypeError, 'an integer is required', [
428455
'2',
429456
'aa',
@@ -452,7 +479,11 @@ async def test_arrays(self):
452479
(
453480
r"SELECT '{{{{{{1}}}}}}'::int[]",
454481
((((((1,),),),),),)
455-
)
482+
),
483+
(
484+
r"SELECT '{1, 2, NULL}'::int[]::anyarray",
485+
(1, 2, None)
486+
),
456487
]
457488

458489
for sql, expected in cases:
@@ -464,6 +495,7 @@ async def test_arrays(self):
464495
await self.con.fetchval("SELECT '{{{{{{{1}}}}}}}'::int[]")
465496

466497
cases = [
498+
(None,),
467499
(1, 2, 3, 4, 5, 6),
468500
((1, 2), (4, 5), (6, 7)),
469501
(((1,), (2,)), ((4,), (5,)), ((None,), (7,))),
@@ -559,6 +591,10 @@ async def test_composites(self):
559591
self.assertEqual(at[0].type.name, 'test_composite')
560592
self.assertEqual(at[0].type.kind, 'composite')
561593

594+
res = await self.con.fetchval('''
595+
SELECT $1::test_composite
596+
''', res)
597+
562598
finally:
563599
await self.con.execute('DROP TYPE test_composite')
564600

@@ -645,13 +681,29 @@ async def test_extra_codec_alias(self):
645681
await self.con.set_builtin_type_codec(
646682
'hstore', codec_name='pg_contrib.hstore')
647683

684+
cases = [
685+
{'ham': 'spam', 'nada': None},
686+
{}
687+
]
688+
648689
st = await self.con.prepare('''
649690
SELECT $1::hstore AS result
650691
''')
651-
res = await st.fetchrow({'ham': 'spam', 'nada': None})
652-
res = res['result']
653692

654-
self.assertEqual(res, {'ham': 'spam', 'nada': None})
693+
for case in cases:
694+
res = await st.fetchval(case)
695+
self.assertEqual(res, case)
696+
697+
res = await self.con.fetchval('''
698+
SELECT $1::hstore AS result
699+
''', (('foo', 2), ('bar', 3)))
700+
701+
self.assertEqual(res, {'foo': '2', 'bar': '3'})
702+
703+
with self.assertRaisesRegex(ValueError, 'null value not allowed'):
704+
await self.con.fetchval('''
705+
SELECT $1::hstore AS result
706+
''', {None: '1'})
655707

656708
finally:
657709
await self.con.execute('''
@@ -728,3 +780,83 @@ def hstore_encoder(obj):
728780
await self.con.execute('''
729781
DROP EXTENSION hstore
730782
''')
783+
784+
async def test_custom_codec_binary(self):
785+
"""Test encoding/decoding using a custom codec in binary mode."""
786+
await self.con.execute('''
787+
CREATE EXTENSION IF NOT EXISTS hstore
788+
''')
789+
790+
longstruct = struct.Struct('!L')
791+
ulong_unpack = lambda b: longstruct.unpack_from(b)[0]
792+
ulong_pack = longstruct.pack
793+
794+
def hstore_decoder(data):
795+
result = {}
796+
n = ulong_unpack(data)
797+
view = memoryview(data)
798+
ptr = 4
799+
800+
for i in range(n):
801+
klen = ulong_unpack(view[ptr:ptr + 4])
802+
ptr += 4
803+
k = bytes(view[ptr:ptr + klen]).decode()
804+
ptr += klen
805+
vlen = ulong_unpack(view[ptr:ptr + 4])
806+
ptr += 4
807+
if vlen == -1:
808+
v = None
809+
else:
810+
v = bytes(view[ptr:ptr + vlen]).decode()
811+
ptr += vlen
812+
813+
result[k] = v
814+
815+
return result
816+
817+
def hstore_encoder(obj):
818+
buffer = bytearray(ulong_pack(len(obj)))
819+
820+
for k, v in obj.items():
821+
kenc = k.encode()
822+
buffer += ulong_pack(len(kenc)) + kenc
823+
824+
if v is None:
825+
buffer += b'\xFF\xFF\xFF\xFF' # -1
826+
else:
827+
venc = v.encode()
828+
buffer += ulong_pack(len(venc)) + venc
829+
830+
return buffer
831+
832+
try:
833+
await self.con.set_type_codec('hstore', encoder=hstore_encoder,
834+
decoder=hstore_decoder,
835+
binary=True)
836+
837+
st = await self.con.prepare('''
838+
SELECT $1::hstore AS result
839+
''')
840+
841+
res = await st.fetchrow({'ham': 'spam'})
842+
res = res['result']
843+
844+
self.assertEqual(res, {'ham': 'spam'})
845+
846+
pt = st.get_parameters()
847+
self.assertTrue(isinstance(pt, tuple))
848+
self.assertEqual(len(pt), 1)
849+
self.assertEqual(pt[0].name, 'hstore')
850+
self.assertEqual(pt[0].kind, 'scalar')
851+
self.assertEqual(pt[0].schema, 'public')
852+
853+
at = st.get_attributes()
854+
self.assertTrue(isinstance(at, tuple))
855+
self.assertEqual(len(at), 1)
856+
self.assertEqual(at[0].name, 'result')
857+
self.assertEqual(at[0].type, pt[0])
858+
859+
finally:
860+
await self.con.execute('''
861+
DROP EXTENSION hstore
862+
''')

0 commit comments

Comments
 (0)