20
20
import sys
21
21
import time
22
22
import typing
23
- import typing_extensions
24
23
import urllib .parse
25
24
import warnings
26
25
import inspect
27
26
27
+ # Work around https://github.com/microsoft/pyright/issues/3012
28
+ if sys .version_info >= (3 , 8 ):
29
+ from typing import Final
30
+ else :
31
+ from typing_extensions import Final
32
+
28
33
from . import compat
29
34
from . import exceptions
30
35
from . import protocol
37
42
bound = 'connection.Connection[typing.Any]'
38
43
)
39
44
_Protocol = typing .TypeVar ('_Protocol' , bound = 'protocol.Protocol[typing.Any]' )
45
+ _AsyncProtocol = typing .TypeVar (
46
+ '_AsyncProtocol' , bound = 'asyncio.protocols.Protocol'
47
+ )
40
48
_Record = typing .TypeVar ('_Record' , bound = protocol .Record )
41
49
_SSLMode = typing .TypeVar ('_SSLMode' , bound = 'SSLMode' )
42
50
43
- _TPTupleType = typing .Tuple [asyncio .WriteTransport , _Protocol ]
51
+ _TPTupleType = typing .Tuple [asyncio .WriteTransport , _AsyncProtocol ]
44
52
AddrType = typing .Union [typing .Tuple [str , int ], str ]
45
53
SSLStringValues = compat .Literal [
46
54
'disable' , 'prefer' , 'allow' , 'require' , 'verify-ca' , 'verify-full'
@@ -76,7 +84,12 @@ def parse(
76
84
77
85
class _ConnectionParameters (typing .NamedTuple ):
78
86
user : str
79
- password : typing .Optional [str ]
87
+ password : typing .Union [
88
+ str ,
89
+ typing .Callable [[], str ],
90
+ typing .Callable [[], typing .Awaitable [str ]],
91
+ None
92
+ ]
80
93
database : str
81
94
ssl : typing .Optional [_ParsedSSLType ]
82
95
sslmode : typing .Optional [SSLMode ]
@@ -92,8 +105,8 @@ class _ClientConfiguration(typing.NamedTuple):
92
105
max_cacheable_statement_size : int
93
106
94
107
95
- _system : typing_extensions . Final = platform .uname ().system
96
- PGPASSFILE : typing_extensions . Final = (
108
+ _system : Final = platform .uname ().system
109
+ PGPASSFILE : Final = (
97
110
'pgpass.conf' if _system == 'Windows' else '.pgpass'
98
111
)
99
112
@@ -102,7 +115,7 @@ def _read_password_file(
102
115
passfile : pathlib .Path
103
116
) -> typing .List [typing .Tuple [str , ...]]:
104
117
105
- passtab = []
118
+ passtab : typing . List [ typing . Tuple [ str , ...]] = []
106
119
107
120
try :
108
121
if not passfile .exists ():
@@ -295,7 +308,8 @@ def _parse_connect_dsn_and_args(*, dsn: typing.Optional[str],
295
308
ssl : typing .Optional [SSLType ],
296
309
direct_tls : bool ,
297
310
connect_timeout : float ,
298
- server_settings : typing .Dict [str , str ]) \
311
+ server_settings : typing .Optional [
312
+ typing .Dict [str , str ]]) \
299
313
-> typing .Tuple [typing .List [typing .Union [typing .Tuple [str , int ], str ]],
300
314
_ConnectionParameters ]:
301
315
# `auth_hosts` is the version of host information for the purposes
@@ -672,7 +686,8 @@ def _parse_connect_arguments(*, dsn: typing.Optional[str],
672
686
max_cacheable_statement_size : int ,
673
687
ssl : typing .Optional [SSLType ],
674
688
direct_tls : bool ,
675
- server_settings : typing .Dict [str , str ]) \
689
+ server_settings : typing .Optional [
690
+ typing .Dict [str , str ]]) \
676
691
-> typing .Tuple [typing .List [AddrType ], _ConnectionParameters ,
677
692
_ClientConfiguration ]:
678
693
@@ -786,7 +801,7 @@ async def _create_ssl_connection(
786
801
loop : asyncio .AbstractEventLoop ,
787
802
ssl_context : ssl_module .SSLContext ,
788
803
ssl_is_advisory : typing .Optional [bool ] = False
789
- ) -> _TPTupleType [typing .Union [ _Protocol , '_CancelProto' ] ]:
804
+ ) -> _TPTupleType [typing .Any ]:
790
805
791
806
tr , pr = typing .cast (
792
807
typing .Tuple [asyncio .WriteTransport , TLSUpgradeProto ],
@@ -915,14 +930,17 @@ async def __connect_addr(
915
930
) -> _Connection :
916
931
connected = _create_future (loop )
917
932
918
- proto_factory = lambda : protocol .Protocol (
933
+ proto_factory : typing .Callable [
934
+ [], 'protocol.Protocol[_Record]'
935
+ ] = lambda : protocol .Protocol (
919
936
addr , connected , params , record_class , loop )
920
937
921
938
if isinstance (addr , str ):
922
939
# UNIX socket
923
940
connector = typing .cast (
924
- typing .Coroutine [typing .Any , None ,
925
- _TPTupleType ['protocol.Protocol[_Record]' ]],
941
+ typing .Coroutine [
942
+ typing .Any , None , _TPTupleType ['protocol.Protocol[_Record]' ]
943
+ ],
926
944
loop .create_unix_connection (proto_factory , addr )
927
945
)
928
946
@@ -939,9 +957,11 @@ async def __connect_addr(
939
957
ssl_is_advisory = params .sslmode == SSLMode .prefer )
940
958
else :
941
959
connector = typing .cast (
942
- typing .Coroutine [typing .Any , None ,
943
- _TPTupleType ['protocol.Protocol[_Record]' ]],
944
- loop .create_connection (proto_factory , * addr ))
960
+ typing .Coroutine [
961
+ typing .Any , None , _TPTupleType ['protocol.Protocol[_Record]' ]
962
+ ],
963
+ loop .create_connection (proto_factory , * addr )
964
+ )
945
965
946
966
connector_future = asyncio .ensure_future (connector )
947
967
before = time .monotonic ()
@@ -1043,13 +1063,12 @@ async def _cancel(*, loop: asyncio.AbstractEventLoop,
1043
1063
params : _ConnectionParameters ,
1044
1064
backend_pid : int , backend_secret : str ) -> None :
1045
1065
1046
- proto_factory = lambda : _CancelProto (loop )
1066
+ proto_factory : typing .Callable [
1067
+ [], _CancelProto
1068
+ ] = lambda : _CancelProto (loop )
1047
1069
1048
1070
if isinstance (addr , str ):
1049
- tr , pr = typing .cast (
1050
- typing .Tuple [asyncio .WriteTransport , _CancelProto ],
1051
- await loop .create_unix_connection (proto_factory , addr )
1052
- )
1071
+ tr , pr = await loop .create_unix_connection (proto_factory , addr )
1053
1072
else :
1054
1073
if params .ssl and params .sslmode != SSLMode .allow :
1055
1074
tr , pr = await _create_ssl_connection (
@@ -1059,17 +1078,15 @@ async def _cancel(*, loop: asyncio.AbstractEventLoop,
1059
1078
ssl_context = params .ssl ,
1060
1079
ssl_is_advisory = params .sslmode == SSLMode .prefer )
1061
1080
else :
1062
- tr , pr = typing .cast (
1063
- typing .Tuple [asyncio .WriteTransport , _CancelProto ],
1064
- await loop .create_connection (proto_factory , * addr ))
1081
+ tr , pr = await loop .create_connection (proto_factory , * addr )
1065
1082
_set_nodelay (_get_socket (tr ))
1066
1083
1067
1084
# Pack a CancelRequest message
1068
1085
msg = struct .pack ('!llll' , 16 , 80877102 , backend_pid , backend_secret )
1069
1086
1070
1087
try :
1071
- tr .write (msg )
1072
- await pr .on_disconnect
1088
+ typing . cast ( typing . Any , tr ) .write (msg )
1089
+ await typing . cast ( typing . Any , pr ) .on_disconnect
1073
1090
finally :
1074
1091
tr .close ()
1075
1092
0 commit comments