10
10
import ipaddress
11
11
import math
12
12
import random
13
+ import struct
13
14
import uuid
14
15
15
16
import asyncpg
@@ -31,6 +32,9 @@ def _timezone(offset):
31
32
32
33
33
34
type_samples = [
35
+ ('bool' , 'bool' , (
36
+ True , False ,
37
+ )),
34
38
('smallint' , 'int2' , (
35
39
- 2 ** 15 + 1 , 2 ** 15 - 1 ,
36
40
- 1 , 0 , 1 ,
@@ -132,7 +136,8 @@ def _timezone(offset):
132
136
bytes (range (255 , - 1 , - 1 )),
133
137
b'\x00 \x00 ' ,
134
138
b'foo' ,
135
- b'f' * 1024 * 1024
139
+ b'f' * 1024 * 1024 ,
140
+ dict (input = bytearray (b'\x02 \x01 ' ), output = b'\x02 \x01 ' ),
136
141
)),
137
142
('text' , 'text' , (
138
143
'' ,
@@ -156,6 +161,7 @@ def _timezone(offset):
156
161
datetime .date (2000 , 1 , 1 ),
157
162
datetime .date (500 , 1 , 1 ),
158
163
datetime .date (1 , 1 , 1 ),
164
+ infinity_date ,
159
165
]),
160
166
('time' , 'time' , [
161
167
datetime .time (12 , 15 , 20 ),
@@ -191,7 +197,9 @@ def _timezone(offset):
191
197
]),
192
198
('uuid' , 'uuid' , [
193
199
uuid .UUID ('38a4ff5a-3a56-11e6-a6c2-c8f73323c6d4' ),
194
- uuid .UUID ('00000000-0000-0000-0000-000000000000' )
200
+ uuid .UUID ('00000000-0000-0000-0000-000000000000' ),
201
+ {'input' : '00000000-0000-0000-0000-000000000000' ,
202
+ 'output' : uuid .UUID ('00000000-0000-0000-0000-000000000000' )}
195
203
]),
196
204
('uuid[]' , 'uuid[]' , [
197
205
(uuid .UUID ('38a4ff5a-3a56-11e6-a6c2-c8f73323c6d4' ),
@@ -294,11 +302,21 @@ def _timezone(offset):
294
302
asyncpg .BitString (),
295
303
asyncpg .BitString .frombytes (b'\x00 ' , bitlength = 3 ),
296
304
asyncpg .BitString ('0000 0000 1' ),
305
+ dict (input = b'\x01 ' , output = asyncpg .BitString ('0000 0001' )),
306
+ dict (input = bytearray (b'\x02 ' ), output = asyncpg .BitString ('0000 0010' )),
297
307
]),
298
308
('path' , 'path' , [
299
309
asyncpg .Path (asyncpg .Point (0.0 , 0.0 ), asyncpg .Point (1.0 , 1.0 )),
300
310
asyncpg .Path (asyncpg .Point (0.0 , 0.0 ), asyncpg .Point (1.0 , 1.0 ),
301
311
is_closed = True ),
312
+ dict (input = ((0.0 , 0.0 ), (1.0 , 1.0 )),
313
+ output = asyncpg .Path (asyncpg .Point (0.0 , 0.0 ),
314
+ asyncpg .Point (1.0 , 1.0 ),
315
+ is_closed = True )),
316
+ dict (input = [(0.0 , 0.0 ), (1.0 , 1.0 )],
317
+ output = asyncpg .Path (asyncpg .Point (0.0 , 0.0 ),
318
+ asyncpg .Point (1.0 , 1.0 ),
319
+ is_closed = False )),
302
320
]),
303
321
('point' , 'point' , [
304
322
asyncpg .Point (0.0 , 0.0 ),
@@ -334,22 +352,28 @@ async def test_standard_codecs(self):
334
352
335
353
for sample in sample_data :
336
354
with self .subTest (sample = sample , typname = typname ):
337
- rsample = await st .fetchval (sample )
355
+ if isinstance (sample , dict ):
356
+ inputval = sample ['input' ]
357
+ outputval = sample ['output' ]
358
+ else :
359
+ inputval = outputval = sample
360
+
361
+ result = await st .fetchval (inputval )
338
362
err_msg = (
339
- "failed to return {} object data as-is; "
340
- "gave {!r}, received {!r}" .format (
341
- typname , sample , rsample ))
363
+ "unexpected result for {} when passing {!r}: "
364
+ "received {!r}, expected {!r}" .format (
365
+ typname , inputval , result , outputval ))
342
366
343
367
if typname .startswith ('float' ):
344
- if math .isnan (sample ):
345
- if not math .isnan (rsample ):
368
+ if math .isnan (outputval ):
369
+ if not math .isnan (result ):
346
370
self .fail (err_msg )
347
371
else :
348
372
self .assertTrue (
349
- math .isclose (rsample , sample , rel_tol = 1e-6 ),
373
+ math .isclose (result , outputval , rel_tol = 1e-6 ),
350
374
err_msg )
351
375
else :
352
- self .assertEqual (rsample , sample , err_msg )
376
+ self .assertEqual (result , outputval , err_msg )
353
377
354
378
with self .subTest (sample = None , typname = typname ):
355
379
# Test that None is handled for all types.
@@ -369,10 +393,9 @@ async def test_all_builtin_types_handled(self):
369
393
'core type {} ({}) is unhandled' .format (typename , oid ))
370
394
371
395
async def test_void (self ):
372
- stmt = await self .con .prepare ('select pg_sleep(0)' )
373
- self .assertIsNone (await stmt .fetchval ())
374
-
375
- await self .con .fetchval ('select now($1::void)' , None )
396
+ res = await self .con .fetchval ('select pg_sleep(0)' )
397
+ self .assertIsNone (res )
398
+ await self .con .fetchval ('select now($1::void)' , '' )
376
399
377
400
def test_bitstring (self ):
378
401
bitlen = random .randint (0 , 1000 )
@@ -424,6 +447,10 @@ async def test_invalid_input(self):
424
447
32768 ,
425
448
- 32768
426
449
]),
450
+ ('float4' , ValueError , 'float value too large' , [
451
+ 4.1 * 10 ** 40 ,
452
+ - 4.1 * 10 ** 40 ,
453
+ ]),
427
454
('int4' , TypeError , 'an integer is required' , [
428
455
'2' ,
429
456
'aa' ,
@@ -452,7 +479,11 @@ async def test_arrays(self):
452
479
(
453
480
r"SELECT '{{{{{{1}}}}}}'::int[]" ,
454
481
((((((1 ,),),),),),)
455
- )
482
+ ),
483
+ (
484
+ r"SELECT '{1, 2, NULL}'::int[]::anyarray" ,
485
+ (1 , 2 , None )
486
+ ),
456
487
]
457
488
458
489
for sql , expected in cases :
@@ -464,6 +495,7 @@ async def test_arrays(self):
464
495
await self .con .fetchval ("SELECT '{{{{{{{1}}}}}}}'::int[]" )
465
496
466
497
cases = [
498
+ (None ,),
467
499
(1 , 2 , 3 , 4 , 5 , 6 ),
468
500
((1 , 2 ), (4 , 5 ), (6 , 7 )),
469
501
(((1 ,), (2 ,)), ((4 ,), (5 ,)), ((None ,), (7 ,))),
@@ -559,6 +591,10 @@ async def test_composites(self):
559
591
self .assertEqual (at [0 ].type .name , 'test_composite' )
560
592
self .assertEqual (at [0 ].type .kind , 'composite' )
561
593
594
+ res = await self .con .fetchval ('''
595
+ SELECT $1::test_composite
596
+ ''' , res )
597
+
562
598
finally :
563
599
await self .con .execute ('DROP TYPE test_composite' )
564
600
@@ -645,13 +681,29 @@ async def test_extra_codec_alias(self):
645
681
await self .con .set_builtin_type_codec (
646
682
'hstore' , codec_name = 'pg_contrib.hstore' )
647
683
684
+ cases = [
685
+ {'ham' : 'spam' , 'nada' : None },
686
+ {}
687
+ ]
688
+
648
689
st = await self .con .prepare ('''
649
690
SELECT $1::hstore AS result
650
691
''' )
651
- res = await st .fetchrow ({'ham' : 'spam' , 'nada' : None })
652
- res = res ['result' ]
653
692
654
- self .assertEqual (res , {'ham' : 'spam' , 'nada' : None })
693
+ for case in cases :
694
+ res = await st .fetchval (case )
695
+ self .assertEqual (res , case )
696
+
697
+ res = await self .con .fetchval ('''
698
+ SELECT $1::hstore AS result
699
+ ''' , (('foo' , 2 ), ('bar' , 3 )))
700
+
701
+ self .assertEqual (res , {'foo' : '2' , 'bar' : '3' })
702
+
703
+ with self .assertRaisesRegex (ValueError , 'null value not allowed' ):
704
+ await self .con .fetchval ('''
705
+ SELECT $1::hstore AS result
706
+ ''' , {None : '1' })
655
707
656
708
finally :
657
709
await self .con .execute ('''
@@ -728,3 +780,83 @@ def hstore_encoder(obj):
728
780
await self .con .execute ('''
729
781
DROP EXTENSION hstore
730
782
''' )
783
+
784
+ async def test_custom_codec_binary (self ):
785
+ """Test encoding/decoding using a custom codec in binary mode."""
786
+ await self .con .execute ('''
787
+ CREATE EXTENSION IF NOT EXISTS hstore
788
+ ''' )
789
+
790
+ longstruct = struct .Struct ('!L' )
791
+ ulong_unpack = lambda b : longstruct .unpack_from (b )[0 ]
792
+ ulong_pack = longstruct .pack
793
+
794
+ def hstore_decoder (data ):
795
+ result = {}
796
+ n = ulong_unpack (data )
797
+ view = memoryview (data )
798
+ ptr = 4
799
+
800
+ for i in range (n ):
801
+ klen = ulong_unpack (view [ptr :ptr + 4 ])
802
+ ptr += 4
803
+ k = bytes (view [ptr :ptr + klen ]).decode ()
804
+ ptr += klen
805
+ vlen = ulong_unpack (view [ptr :ptr + 4 ])
806
+ ptr += 4
807
+ if vlen == - 1 :
808
+ v = None
809
+ else :
810
+ v = bytes (view [ptr :ptr + vlen ]).decode ()
811
+ ptr += vlen
812
+
813
+ result [k ] = v
814
+
815
+ return result
816
+
817
+ def hstore_encoder (obj ):
818
+ buffer = bytearray (ulong_pack (len (obj )))
819
+
820
+ for k , v in obj .items ():
821
+ kenc = k .encode ()
822
+ buffer += ulong_pack (len (kenc )) + kenc
823
+
824
+ if v is None :
825
+ buffer += b'\xFF \xFF \xFF \xFF ' # -1
826
+ else :
827
+ venc = v .encode ()
828
+ buffer += ulong_pack (len (venc )) + venc
829
+
830
+ return buffer
831
+
832
+ try :
833
+ await self .con .set_type_codec ('hstore' , encoder = hstore_encoder ,
834
+ decoder = hstore_decoder ,
835
+ binary = True )
836
+
837
+ st = await self .con .prepare ('''
838
+ SELECT $1::hstore AS result
839
+ ''' )
840
+
841
+ res = await st .fetchrow ({'ham' : 'spam' })
842
+ res = res ['result' ]
843
+
844
+ self .assertEqual (res , {'ham' : 'spam' })
845
+
846
+ pt = st .get_parameters ()
847
+ self .assertTrue (isinstance (pt , tuple ))
848
+ self .assertEqual (len (pt ), 1 )
849
+ self .assertEqual (pt [0 ].name , 'hstore' )
850
+ self .assertEqual (pt [0 ].kind , 'scalar' )
851
+ self .assertEqual (pt [0 ].schema , 'public' )
852
+
853
+ at = st .get_attributes ()
854
+ self .assertTrue (isinstance (at , tuple ))
855
+ self .assertEqual (len (at ), 1 )
856
+ self .assertEqual (at [0 ].name , 'result' )
857
+ self .assertEqual (at [0 ].type , pt [0 ])
858
+
859
+ finally :
860
+ await self .con .execute ('''
861
+ DROP EXTENSION hstore
862
+ ''' )
0 commit comments