2020import sys
2121import time
2222import typing
23- import typing_extensions
2423import urllib .parse
2524import warnings
2625import inspect
2726
27+ # Work around https://github.com/microsoft/pyright/issues/3012
28+ if sys .version_info >= (3 , 8 ):
29+ from typing import Final
30+ else :
31+ from typing_extensions import Final
32+
2833from . import compat
2934from . import exceptions
3035from . import protocol
3742 bound = 'connection.Connection[typing.Any]'
3843)
3944_Protocol = typing .TypeVar ('_Protocol' , bound = 'protocol.Protocol[typing.Any]' )
45+ _AsyncProtocol = typing .TypeVar (
46+ '_AsyncProtocol' , bound = 'asyncio.protocols.Protocol'
47+ )
4048_Record = typing .TypeVar ('_Record' , bound = protocol .Record )
4149_SSLMode = typing .TypeVar ('_SSLMode' , bound = 'SSLMode' )
4250
43- _TPTupleType = typing .Tuple [asyncio .WriteTransport , _Protocol ]
51+ _TPTupleType = typing .Tuple [asyncio .WriteTransport , _AsyncProtocol ]
4452AddrType = typing .Union [typing .Tuple [str , int ], str ]
4553SSLStringValues = compat .Literal [
4654 'disable' , 'prefer' , 'allow' , 'require' , 'verify-ca' , 'verify-full'
@@ -76,7 +84,12 @@ def parse(
7684
7785class _ConnectionParameters (typing .NamedTuple ):
7886 user : str
79- password : typing .Optional [str ]
87+ password : typing .Union [
88+ str ,
89+ typing .Callable [[], str ],
90+ typing .Callable [[], typing .Awaitable [str ]],
91+ None
92+ ]
8093 database : str
8194 ssl : typing .Optional [_ParsedSSLType ]
8295 sslmode : typing .Optional [SSLMode ]
@@ -92,8 +105,8 @@ class _ClientConfiguration(typing.NamedTuple):
92105 max_cacheable_statement_size : int
93106
94107
95- _system : typing_extensions . Final = platform .uname ().system
96- PGPASSFILE : typing_extensions . Final = (
108+ _system : Final = platform .uname ().system
109+ PGPASSFILE : Final = (
97110 'pgpass.conf' if _system == 'Windows' else '.pgpass'
98111)
99112
@@ -102,7 +115,7 @@ def _read_password_file(
102115 passfile : pathlib .Path
103116) -> typing .List [typing .Tuple [str , ...]]:
104117
105- passtab = []
118+ passtab : typing . List [ typing . Tuple [ str , ...]] = []
106119
107120 try :
108121 if not passfile .exists ():
@@ -295,7 +308,8 @@ def _parse_connect_dsn_and_args(*, dsn: typing.Optional[str],
295308 ssl : typing .Optional [SSLType ],
296309 direct_tls : bool ,
297310 connect_timeout : float ,
298- server_settings : typing .Dict [str , str ]) \
311+ server_settings : typing .Optional [
312+ typing .Dict [str , str ]]) \
299313 -> typing .Tuple [typing .List [typing .Union [typing .Tuple [str , int ], str ]],
300314 _ConnectionParameters ]:
301315 # `auth_hosts` is the version of host information for the purposes
@@ -672,7 +686,8 @@ def _parse_connect_arguments(*, dsn: typing.Optional[str],
672686 max_cacheable_statement_size : int ,
673687 ssl : typing .Optional [SSLType ],
674688 direct_tls : bool ,
675- server_settings : typing .Dict [str , str ]) \
689+ server_settings : typing .Optional [
690+ typing .Dict [str , str ]]) \
676691 -> typing .Tuple [typing .List [AddrType ], _ConnectionParameters ,
677692 _ClientConfiguration ]:
678693
@@ -786,7 +801,7 @@ async def _create_ssl_connection(
786801 loop : asyncio .AbstractEventLoop ,
787802 ssl_context : ssl_module .SSLContext ,
788803 ssl_is_advisory : typing .Optional [bool ] = False
789- ) -> _TPTupleType [typing .Union [ _Protocol , '_CancelProto' ] ]:
804+ ) -> _TPTupleType [typing .Any ]:
790805
791806 tr , pr = typing .cast (
792807 typing .Tuple [asyncio .WriteTransport , TLSUpgradeProto ],
@@ -915,14 +930,17 @@ async def __connect_addr(
915930) -> _Connection :
916931 connected = _create_future (loop )
917932
918- proto_factory = lambda : protocol .Protocol (
933+ proto_factory : typing .Callable [
934+ [], 'protocol.Protocol[_Record]'
935+ ] = lambda : protocol .Protocol (
919936 addr , connected , params , record_class , loop )
920937
921938 if isinstance (addr , str ):
922939 # UNIX socket
923940 connector = typing .cast (
924- typing .Coroutine [typing .Any , None ,
925- _TPTupleType ['protocol.Protocol[_Record]' ]],
941+ typing .Coroutine [
942+ typing .Any , None , _TPTupleType ['protocol.Protocol[_Record]' ]
943+ ],
926944 loop .create_unix_connection (proto_factory , addr )
927945 )
928946
@@ -939,9 +957,11 @@ async def __connect_addr(
939957 ssl_is_advisory = params .sslmode == SSLMode .prefer )
940958 else :
941959 connector = typing .cast (
942- typing .Coroutine [typing .Any , None ,
943- _TPTupleType ['protocol.Protocol[_Record]' ]],
944- loop .create_connection (proto_factory , * addr ))
960+ typing .Coroutine [
961+ typing .Any , None , _TPTupleType ['protocol.Protocol[_Record]' ]
962+ ],
963+ loop .create_connection (proto_factory , * addr )
964+ )
945965
946966 connector_future = asyncio .ensure_future (connector )
947967 before = time .monotonic ()
@@ -1043,13 +1063,12 @@ async def _cancel(*, loop: asyncio.AbstractEventLoop,
10431063 params : _ConnectionParameters ,
10441064 backend_pid : int , backend_secret : str ) -> None :
10451065
1046- proto_factory = lambda : _CancelProto (loop )
1066+ proto_factory : typing .Callable [
1067+ [], _CancelProto
1068+ ] = lambda : _CancelProto (loop )
10471069
10481070 if isinstance (addr , str ):
1049- tr , pr = typing .cast (
1050- typing .Tuple [asyncio .WriteTransport , _CancelProto ],
1051- await loop .create_unix_connection (proto_factory , addr )
1052- )
1071+ tr , pr = await loop .create_unix_connection (proto_factory , addr )
10531072 else :
10541073 if params .ssl and params .sslmode != SSLMode .allow :
10551074 tr , pr = await _create_ssl_connection (
@@ -1059,17 +1078,15 @@ async def _cancel(*, loop: asyncio.AbstractEventLoop,
10591078 ssl_context = params .ssl ,
10601079 ssl_is_advisory = params .sslmode == SSLMode .prefer )
10611080 else :
1062- tr , pr = typing .cast (
1063- typing .Tuple [asyncio .WriteTransport , _CancelProto ],
1064- await loop .create_connection (proto_factory , * addr ))
1081+ tr , pr = await loop .create_connection (proto_factory , * addr )
10651082 _set_nodelay (_get_socket (tr ))
10661083
10671084 # Pack a CancelRequest message
10681085 msg = struct .pack ('!llll' , 16 , 80877102 , backend_pid , backend_secret )
10691086
10701087 try :
1071- tr .write (msg )
1072- await pr .on_disconnect
1088+ typing . cast ( typing . Any , tr ) .write (msg )
1089+ await typing . cast ( typing . Any , pr ) .on_disconnect
10731090 finally :
10741091 tr .close ()
10751092
0 commit comments