Skip to content

Commit 6d8d482

Browse files
committed
Fixes for pyright
1 parent 9718afc commit 6d8d482

16 files changed

+150
-96
lines changed

asyncpg/__init__.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,5 @@
1717
from ._version import __version__ # NOQA
1818

1919

20-
__all__ = (
21-
'connect',
22-
'create_pool',
23-
'Pool',
24-
'Record',
25-
'Connection'
26-
) + exceptions.__all__
20+
__all__ = ['connect', 'create_pool', 'Pool', 'Record', 'Connection']
21+
__all__ += exceptions.__all__

asyncpg/cluster.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,12 @@
1818
import textwrap
1919
import time
2020
import typing
21-
import typing_extensions
21+
22+
# Work around https://github.com/microsoft/pyright/issues/3012
23+
if sys.version_info >= (3, 8):
24+
from typing import Final
25+
else:
26+
from typing_extensions import Final
2227

2328
import asyncpg
2429
from asyncpg import compat
@@ -35,7 +40,7 @@ class _ConnectionSpec(compat.TypedDict):
3540
port: str
3641

3742

38-
_system: typing_extensions.Final = platform.uname().system
43+
_system: Final = platform.uname().system
3944

4045
if _system == 'Windows':
4146
def platform_exe(name: str) -> str:
@@ -338,7 +343,7 @@ def reset_wal(self, *, oid: typing.Optional[int] = None,
338343
raise ClusterError(
339344
'cannot modify WAL status: cluster is running')
340345

341-
opts = []
346+
opts: typing.List[str] = []
342347
if oid is not None:
343348
opts.extend(['-o', str(oid)])
344349
if xid is not None:
@@ -526,14 +531,15 @@ def _test_connection(self, timeout: int = 60) -> str:
526531
continue
527532

528533
try:
529-
con = loop.run_until_complete(
530-
asyncpg.connect( # type: ignore[arg-type] # noqa: E501
531-
database='postgres',
532-
user='postgres',
533-
timeout=5, loop=loop,
534-
**self._connection_addr
534+
con: 'connection.Connection[typing.Any]' = \
535+
loop.run_until_complete(
536+
asyncpg.connect( # type: ignore[arg-type] # noqa: E501
537+
database='postgres',
538+
user='postgres',
539+
timeout=5, loop=loop,
540+
**self._connection_addr
541+
)
535542
)
536-
)
537543
except (OSError, asyncio.TimeoutError,
538544
exceptions.CannotConnectNowError,
539545
exceptions.PostgresConnectionError):
@@ -562,7 +568,7 @@ def _run_pg_config(self, pg_config_path: str) -> typing.Dict[str, str]:
562568
'pg_config exited with status {:d}: {!r}'.format(
563569
process.returncode, stderr))
564570
else:
565-
config = {}
571+
config: typing.Dict[str, str] = {}
566572

567573
for line in stdout.splitlines():
568574
k, eq, v = line.decode('utf-8').partition('=')

asyncpg/compat.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import pathlib
1010
import sys
1111
import typing
12-
import typing_extensions
1312

1413

1514
if sys.version_info >= (3, 8):
@@ -27,7 +26,6 @@
2726

2827

2928
_T = typing.TypeVar('_T')
30-
PY_37: typing_extensions.Final = sys.version_info >= (3, 7)
3129

3230

3331
if sys.platform == 'win32':

asyncpg/connect_utils.py

Lines changed: 42 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,16 @@
2020
import sys
2121
import time
2222
import typing
23-
import typing_extensions
2423
import urllib.parse
2524
import warnings
2625
import inspect
2726

27+
# Work around https://github.com/microsoft/pyright/issues/3012
28+
if sys.version_info >= (3, 8):
29+
from typing import Final
30+
else:
31+
from typing_extensions import Final
32+
2833
from . import compat
2934
from . import exceptions
3035
from . import protocol
@@ -37,10 +42,13 @@
3742
bound='connection.Connection[typing.Any]'
3843
)
3944
_Protocol = typing.TypeVar('_Protocol', bound='protocol.Protocol[typing.Any]')
45+
_AsyncProtocol = typing.TypeVar(
46+
'_AsyncProtocol', bound='asyncio.protocols.Protocol'
47+
)
4048
_Record = typing.TypeVar('_Record', bound=protocol.Record)
4149
_SSLMode = typing.TypeVar('_SSLMode', bound='SSLMode')
4250

43-
_TPTupleType = typing.Tuple[asyncio.WriteTransport, _Protocol]
51+
_TPTupleType = typing.Tuple[asyncio.WriteTransport, _AsyncProtocol]
4452
AddrType = typing.Union[typing.Tuple[str, int], str]
4553
SSLStringValues = compat.Literal[
4654
'disable', 'prefer', 'allow', 'require', 'verify-ca', 'verify-full'
@@ -76,7 +84,12 @@ def parse(
7684

7785
class _ConnectionParameters(typing.NamedTuple):
7886
user: str
79-
password: typing.Optional[str]
87+
password: typing.Union[
88+
str,
89+
typing.Callable[[], str],
90+
typing.Callable[[], typing.Awaitable[str]],
91+
None
92+
]
8093
database: str
8194
ssl: typing.Optional[_ParsedSSLType]
8295
sslmode: typing.Optional[SSLMode]
@@ -92,8 +105,8 @@ class _ClientConfiguration(typing.NamedTuple):
92105
max_cacheable_statement_size: int
93106

94107

95-
_system: typing_extensions.Final = platform.uname().system
96-
PGPASSFILE: typing_extensions.Final = (
108+
_system: Final = platform.uname().system
109+
PGPASSFILE: Final = (
97110
'pgpass.conf' if _system == 'Windows' else '.pgpass'
98111
)
99112

@@ -102,7 +115,7 @@ def _read_password_file(
102115
passfile: pathlib.Path
103116
) -> typing.List[typing.Tuple[str, ...]]:
104117

105-
passtab = []
118+
passtab: typing.List[typing.Tuple[str, ...]] = []
106119

107120
try:
108121
if not passfile.exists():
@@ -295,7 +308,8 @@ def _parse_connect_dsn_and_args(*, dsn: typing.Optional[str],
295308
ssl: typing.Optional[SSLType],
296309
direct_tls: bool,
297310
connect_timeout: float,
298-
server_settings: typing.Dict[str, str]) \
311+
server_settings: typing.Optional[
312+
typing.Dict[str, str]]) \
299313
-> typing.Tuple[typing.List[typing.Union[typing.Tuple[str, int], str]],
300314
_ConnectionParameters]:
301315
# `auth_hosts` is the version of host information for the purposes
@@ -672,7 +686,8 @@ def _parse_connect_arguments(*, dsn: typing.Optional[str],
672686
max_cacheable_statement_size: int,
673687
ssl: typing.Optional[SSLType],
674688
direct_tls: bool,
675-
server_settings: typing.Dict[str, str]) \
689+
server_settings: typing.Optional[
690+
typing.Dict[str, str]]) \
676691
-> typing.Tuple[typing.List[AddrType], _ConnectionParameters,
677692
_ClientConfiguration]:
678693

@@ -786,7 +801,7 @@ async def _create_ssl_connection(
786801
loop: asyncio.AbstractEventLoop,
787802
ssl_context: ssl_module.SSLContext,
788803
ssl_is_advisory: typing.Optional[bool] = False
789-
) -> _TPTupleType[typing.Union[_Protocol, '_CancelProto']]:
804+
) -> _TPTupleType[typing.Any]:
790805

791806
tr, pr = typing.cast(
792807
typing.Tuple[asyncio.WriteTransport, TLSUpgradeProto],
@@ -915,14 +930,17 @@ async def __connect_addr(
915930
) -> _Connection:
916931
connected = _create_future(loop)
917932

918-
proto_factory = lambda: protocol.Protocol(
933+
proto_factory: typing.Callable[
934+
[], 'protocol.Protocol[_Record]'
935+
] = lambda: protocol.Protocol(
919936
addr, connected, params, record_class, loop)
920937

921938
if isinstance(addr, str):
922939
# UNIX socket
923940
connector = typing.cast(
924-
typing.Coroutine[typing.Any, None,
925-
_TPTupleType['protocol.Protocol[_Record]']],
941+
typing.Coroutine[
942+
typing.Any, None, _TPTupleType['protocol.Protocol[_Record]']
943+
],
926944
loop.create_unix_connection(proto_factory, addr)
927945
)
928946

@@ -939,9 +957,11 @@ async def __connect_addr(
939957
ssl_is_advisory=params.sslmode == SSLMode.prefer)
940958
else:
941959
connector = typing.cast(
942-
typing.Coroutine[typing.Any, None,
943-
_TPTupleType['protocol.Protocol[_Record]']],
944-
loop.create_connection(proto_factory, *addr))
960+
typing.Coroutine[
961+
typing.Any, None, _TPTupleType['protocol.Protocol[_Record]']
962+
],
963+
loop.create_connection(proto_factory, *addr)
964+
)
945965

946966
connector_future = asyncio.ensure_future(connector)
947967
before = time.monotonic()
@@ -1043,13 +1063,12 @@ async def _cancel(*, loop: asyncio.AbstractEventLoop,
10431063
params: _ConnectionParameters,
10441064
backend_pid: int, backend_secret: str) -> None:
10451065

1046-
proto_factory = lambda: _CancelProto(loop)
1066+
proto_factory: typing.Callable[
1067+
[], _CancelProto
1068+
] = lambda: _CancelProto(loop)
10471069

10481070
if isinstance(addr, str):
1049-
tr, pr = typing.cast(
1050-
typing.Tuple[asyncio.WriteTransport, _CancelProto],
1051-
await loop.create_unix_connection(proto_factory, addr)
1052-
)
1071+
tr, pr = await loop.create_unix_connection(proto_factory, addr)
10531072
else:
10541073
if params.ssl and params.sslmode != SSLMode.allow:
10551074
tr, pr = await _create_ssl_connection(
@@ -1059,17 +1078,15 @@ async def _cancel(*, loop: asyncio.AbstractEventLoop,
10591078
ssl_context=params.ssl,
10601079
ssl_is_advisory=params.sslmode == SSLMode.prefer)
10611080
else:
1062-
tr, pr = typing.cast(
1063-
typing.Tuple[asyncio.WriteTransport, _CancelProto],
1064-
await loop.create_connection(proto_factory, *addr))
1081+
tr, pr = await loop.create_connection(proto_factory, *addr)
10651082
_set_nodelay(_get_socket(tr))
10661083

10671084
# Pack a CancelRequest message
10681085
msg = struct.pack('!llll', 16, 80877102, backend_pid, backend_secret)
10691086

10701087
try:
1071-
tr.write(msg)
1072-
await pr.on_disconnect
1088+
typing.cast(typing.Any, tr).write(msg)
1089+
await typing.cast(typing.Any, pr).on_disconnect
10731090
finally:
10741091
tr.close()
10751092

asyncpg/connection.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
_OtherRecord = typing.TypeVar('_OtherRecord', bound=protocol.Record)
4545
_RecordsType = typing.List[_Record]
4646
_RecordsExtraType = typing.Tuple[_RecordsType[_Record], bytes, bool]
47-
_AnyCallable = typing.Callable[..., typing.Any]
4847

4948
OutputType = typing.Union['os.PathLike[typing.Any]',
5049
typing.BinaryIO,
@@ -1060,7 +1059,8 @@ async def copy_from_table(self, table_name: str, *,
10601059
header: typing.Optional[bool] = None,
10611060
quote: typing.Optional[str] = None,
10621061
escape: typing.Optional[str] = None,
1063-
force_quote: typing.Optional[bool] = None,
1062+
force_quote: typing.Union[
1063+
bool, typing.Iterable[str], None] = None,
10641064
encoding: typing.Optional[str] = None) -> str:
10651065
"""Copy table contents to a file or file-like object.
10661066
@@ -1139,7 +1139,8 @@ async def copy_from_query(self, query: str, *args: typing.Any,
11391139
header: typing.Optional[bool] = None,
11401140
quote: typing.Optional[str] = None,
11411141
escape: typing.Optional[str] = None,
1142-
force_quote: typing.Optional[bool] = None,
1142+
force_quote: typing.Union[
1143+
bool, typing.Iterable[str], None] = None,
11431144
encoding: typing.Optional[str] = None) -> str:
11441145
"""Copy the results of a query to a file or file-like object.
11451146
@@ -1212,9 +1213,12 @@ async def copy_to_table(self, table_name: str, *,
12121213
header: typing.Optional[bool] = None,
12131214
quote: typing.Optional[str] = None,
12141215
escape: typing.Optional[str] = None,
1215-
force_quote: typing.Optional[bool] = None,
1216-
force_not_null: typing.Optional[bool] = None,
1217-
force_null: typing.Optional[bool] = None,
1216+
force_quote: typing.Union[
1217+
bool, typing.Iterable[str], None] = None,
1218+
force_not_null: typing.Union[
1219+
bool, typing.Iterable[str], None] = None,
1220+
force_null: typing.Union[
1221+
bool, typing.Iterable[str], None] = None,
12181222
encoding: typing.Optional[str] = None) -> str:
12191223
"""Copy data to the specified table.
12201224
@@ -1390,13 +1394,16 @@ def _format_copy_opts(self, *,
13901394
header: typing.Optional[bool] = None,
13911395
quote: typing.Optional[str] = None,
13921396
escape: typing.Optional[str] = None,
1393-
force_quote: typing.Optional[bool] = None,
1394-
force_not_null: typing.Optional[bool] = None,
1395-
force_null: typing.Optional[bool] = None,
1397+
force_quote: typing.Union[
1398+
bool, typing.Iterable[str], None] = None,
1399+
force_not_null: typing.Union[
1400+
bool, typing.Iterable[str], None] = None,
1401+
force_null: typing.Union[
1402+
bool, typing.Iterable[str], None] = None,
13961403
encoding: typing.Optional[str] = None) -> str:
13971404
kwargs = dict(locals())
13981405
kwargs.pop('self')
1399-
opts = []
1406+
opts: typing.List[str] = []
14001407

14011408
if force_quote is not None and isinstance(force_quote, bool):
14021409
kwargs.pop('force_quote')
@@ -1975,7 +1982,7 @@ def _get_reset_query(self) -> str:
19751982

19761983
caps = self._server_caps
19771984

1978-
_reset_query = []
1985+
_reset_query: typing.List[str] = []
19791986
if caps.advisory_locks:
19801987
_reset_query.append('SELECT pg_advisory_unlock_all();')
19811988
if caps.sql_close_all:
@@ -2337,7 +2344,8 @@ async def __execute(
23372344
) -> typing.Tuple[
23382345
typing.Any, '_cprotocol.PreparedStatementState[typing.Any]'
23392346
]:
2340-
executor = lambda stmt, timeout: self._protocol.bind_execute(
2347+
executor: Executor[_OtherRecord] = \
2348+
lambda stmt, timeout: self._protocol.bind_execute(
23412349
stmt, args, '', limit, return_status, timeout)
23422350
timeout = self._protocol._get_timeout(timeout)
23432351
return await self._do_execute(
@@ -2351,7 +2359,8 @@ async def __execute(
23512359
async def _executemany(self, query: str,
23522360
args: typing.Iterable[typing.Sequence[typing.Any]],
23532361
timeout: typing.Optional[float]) -> None:
2354-
executor = lambda stmt, timeout: self._protocol.bind_execute_many(
2362+
executor: Executor[_Record] = \
2363+
lambda stmt, timeout: self._protocol.bind_execute_many(
23552364
stmt, args, '', timeout)
23562365
timeout = self._protocol._get_timeout(timeout)
23572366
with self._stmt_exclusive_section:
@@ -3222,7 +3231,7 @@ def _extract_stack(limit: int = 10) -> str:
32223231
finally:
32233232
del frame
32243233

3225-
apg_path = asyncpg.__path__[0]
3234+
apg_path = list(asyncpg.__path__)[0]
32263235
i = 0
32273236
while i < len(stack) and stack[i][0].startswith(apg_path):
32283237
i += 1

asyncpg/connresource.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,14 @@
1616
from . import connection as _connection
1717

1818

19-
_ConnectionResource = typing.TypeVar('_ConnectionResource',
20-
bound='ConnectionResource')
2119
_Callable = typing.TypeVar('_Callable', bound=typing.Callable[..., typing.Any])
2220

2321

2422
def guarded(meth: _Callable) -> _Callable:
2523
"""A decorator to add a sanity check to ConnectionResource methods."""
2624

2725
@functools.wraps(meth)
28-
def _check(self: _ConnectionResource,
26+
def _check(self: 'ConnectionResource',
2927
*args: typing.Any,
3028
**kwargs: typing.Any) -> typing.Any:
3129
self._check_conn_validity(meth.__name__)

asyncpg/cursor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ async def _exec(self, n: int,
235235
return buffer
236236

237237
def __repr__(self) -> str:
238-
attrs = []
238+
attrs: typing.List[str] = []
239239
if self._exhausted:
240240
attrs.append('exhausted')
241241
attrs.append('') # to separate from id

0 commit comments

Comments
 (0)