Skip to content

Commit a0cc7c3

Browse files
committed
Allow any non-trivial Container to be passed as an array argument
1 parent 64525ea commit a0cc7c3

File tree

4 files changed

+20
-7
lines changed

4 files changed

+20
-7
lines changed

asyncpg/protocol/codecs/array.pyx

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
66

77

8+
from collections.abc import Container as ContainerABC
9+
10+
811
DEF ARRAY_MAXDIM = 6 # defined in postgresql/src/includes/c.h
912

1013

@@ -19,8 +22,13 @@ ctypedef object (*decode_func_ex)(ConnectionSettings settings,
1922
const void *arg)
2023

2124

22-
cdef inline bint _is_array(object obj):
23-
return cpython.PyTuple_Check(obj) or cpython.PyList_Check(obj)
25+
cdef inline bint _is_trivial_container(object obj):
26+
return cpython.PyUnicode_Check(obj) or cpython.PyBytes_Check(obj) or \
27+
PyByteArray_Check(obj) or PyMemoryView_Check(obj)
28+
29+
30+
cdef inline _is_container(object obj):
31+
return not _is_trivial_container(obj) and isinstance(obj, ContainerABC)
2432

2533

2634
cdef _get_array_shape(object obj, int32_t *dims, int32_t *ndims):
@@ -37,7 +45,7 @@ cdef _get_array_shape(object obj, int32_t *dims, int32_t *ndims):
3745
dims[ndims[0] - 1] = mylen
3846

3947
for elem in obj:
40-
if _is_array(elem):
48+
if _is_container(elem):
4149
if elemlen == -2:
4250
elemlen = len(elem)
4351
ndims[0] += 1
@@ -80,9 +88,10 @@ cdef inline array_encode(ConnectionSettings settings, WriteBuffer buf,
8088
int32_t ndims = 1
8189
int32_t i
8290

83-
if not _is_array(obj):
91+
if not _is_container(obj):
8492
raise TypeError(
85-
'list or tuple expected (got type {})'.format(type(obj)))
93+
'a non-trivial iterable expected (got type {!r})'.format(
94+
type(obj).__name__))
8695

8796
_get_array_shape(obj, dims, &ndims)
8897

asyncpg/protocol/protocol.pyx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ from asyncpg.protocol cimport record
2323
from asyncpg.protocol.python cimport (
2424
PyMem_Malloc, PyMem_Realloc, PyMem_Calloc, PyMem_Free,
2525
PyMemoryView_GET_BUFFER, PyMemoryView_Check,
26-
PyUnicode_AsUTF8AndSize, PyByteArray_AsString)
26+
PyUnicode_AsUTF8AndSize, PyByteArray_AsString,
27+
PyByteArray_Check)
2728

2829
from cpython cimport PyBuffer_FillInfo, PyBytes_AsString
2930

asyncpg/protocol/python.pxd

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@ cdef extern from "Python.h":
1313
void* PyMem_Calloc(size_t nelem, size_t elsize) # Python >= 3.5!
1414
void PyMem_Free(void *p)
1515

16+
int PyByteArray_Check(object)
17+
1618
int PyMemoryView_Check(object)
1719
Py_buffer *PyMemoryView_GET_BUFFER(object)
20+
1821
char* PyUnicode_AsUTF8AndSize(object unicode, ssize_t *size) except NULL
1922
char* PyByteArray_AsString(object)

tests/test_codecs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,7 @@ async def test_arrays(self):
542542
"SELECT $1::int[]",
543543
[[1], ['t'], [2]])
544544

545-
with self.assertRaisesRegex(TypeError, 'list or tuple expected'):
545+
with self.assertRaisesRegex(TypeError, 'non-trivial iterable expected'):
546546
await self.con.fetchval(
547547
"SELECT $1::int[]",
548548
1)

0 commit comments

Comments
 (0)