Skip to content

Commit 2ac8eb1

Browse files
committed
Improve type completeness
1 parent b227f2f commit 2ac8eb1

File tree

8 files changed

+43
-23
lines changed

8 files changed

+43
-23
lines changed

asyncpg/cluster.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -724,6 +724,8 @@ def start(self, wait: int = 60, *,
724724

725725

726726
class RunningCluster(Cluster):
727+
conn_spec: _ConnectionSpec
728+
727729
def __init__(self, **kwargs: str) -> None:
728730
self.conn_spec = typing.cast(_ConnectionSpec, kwargs)
729731

asyncpg/connect_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -729,6 +729,12 @@ def _parse_connect_arguments(*, dsn: typing.Optional[str],
729729

730730

731731
class TLSUpgradeProto(asyncio.Protocol):
732+
on_data: 'asyncio.Future[bool]'
733+
host: str
734+
port: int
735+
ssl_context: ssl_module.SSLContext
736+
ssl_is_advisory: typing.Optional[bool]
737+
732738
def __init__(self, loop: typing.Optional[asyncio.AbstractEventLoop],
733739
host: str, port: int, ssl_context: ssl_module.SSLContext,
734740
ssl_is_advisory: typing.Optional[bool]) -> None:

asyncpg/connection.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@
2929
from . import protocol
3030
from . import serverversion
3131
from . import transaction
32+
from . import types
3233
from . import utils
34+
from .protocol import protocol as _cprotocol
3335

3436
if typing.TYPE_CHECKING:
35-
from .protocol import protocol as _cprotocol
3637
from . import pool_connection_proxy as _pool
37-
from . import types
3838

3939

4040
_Connection = typing.TypeVar('_Connection', bound='Connection[typing.Any]')
@@ -336,7 +336,7 @@ def get_server_pid(self) -> int:
336336
"""Return the PID of the Postgres server the connection is bound to."""
337337
return self._protocol.get_server_pid()
338338

339-
def get_server_version(self) -> 'types.ServerVersion':
339+
def get_server_version(self) -> types.ServerVersion:
340340
"""Return the version of the connected PostgreSQL server.
341341
342342
The returned value is a named tuple similar to that in
@@ -352,7 +352,7 @@ def get_server_version(self) -> 'types.ServerVersion':
352352
"""
353353
return self._server_version
354354

355-
def get_settings(self) -> '_cprotocol.ConnectionSettings':
355+
def get_settings(self) -> _cprotocol.ConnectionSettings:
356356
"""Return connection settings.
357357
358358
:return: :class:`~asyncpg.ConnectionSettings`.
@@ -3176,8 +3176,8 @@ class ServerCapabilities(typing.NamedTuple):
31763176

31773177

31783178
def _detect_server_capabilities(
3179-
server_version: 'types.ServerVersion',
3180-
connection_settings: '_cprotocol.ConnectionSettings'
3179+
server_version: types.ServerVersion,
3180+
connection_settings: _cprotocol.ConnectionSettings
31813181
) -> ServerCapabilities:
31823182
if hasattr(connection_settings, 'padb_revision'):
31833183
# Amazon Redshift detected.

asyncpg/exceptions/_base.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,13 @@ class UnknownPostgresError(FatalPostgresError):
5656

5757
class InterfaceMessage:
5858
args: typing.Tuple[typing.Any, ...]
59+
detail: typing.Optional[str]
60+
hint: typing.Optional[str]
5961

6062
def __init__(self, *, detail: typing.Optional[str] = None,
6163
hint: typing.Optional[str] = None) -> None:
62-
self.detail: typing.Optional[str] = detail
63-
self.hint: typing.Optional[str] = hint
64+
self.detail = detail
65+
self.hint = hint
6466

6567
def __str__(self) -> str:
6668
msg: str = self.args[0]
@@ -118,13 +120,17 @@ class ProtocolError(InternalClientError):
118120
class OutdatedSchemaCacheError(InternalClientError):
119121
"""A value decoding error caused by a schema change before row fetching."""
120122

123+
schema_name: typing.Optional[str]
124+
data_type_name: typing.Optional[str]
125+
position: typing.Optional[str]
126+
121127
def __init__(self, msg: str, *, schema: typing.Optional[str] = None,
122128
data_type: typing.Optional[str] = None,
123129
position: typing.Optional[str] = None) -> None:
124130
super().__init__(msg)
125-
self.schema_name: typing.Optional[str] = schema
126-
self.data_type_name: typing.Optional[str] = data_type
127-
self.position: typing.Optional[str] = position
131+
self.schema_name = schema
132+
self.data_type_name = data_type
133+
self.position = position
128134

129135

130136
class PostgresLogMessage(PostgresMessage):

asyncpg/pool.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from . import protocol
2020

2121

22-
logger = logging.getLogger(__name__)
22+
_logger = logging.getLogger(__name__)
2323

2424
_Connection = typing.TypeVar(
2525
'_Connection',
@@ -994,10 +994,10 @@ async def close(self) -> None:
994994
self._closing = False
995995

996996
def _warn_on_long_close(self) -> None:
997-
logger.warning('Pool.close() is taking over 60 seconds to complete. '
998-
'Check if you have any unreleased connections left. '
999-
'Use asyncio.wait_for() to set a timeout for '
1000-
'Pool.close().')
997+
_logger.warning('Pool.close() is taking over 60 seconds to complete. '
998+
'Check if you have any unreleased connections left. '
999+
'Use asyncio.wait_for() to set a timeout for '
1000+
'Pool.close().')
10011001

10021002
def terminate(self) -> None:
10031003
"""Terminate all connections in the pool."""
@@ -1060,14 +1060,19 @@ class PoolAcquireContext(typing.Generic[_Record]):
10601060

10611061
__slots__ = ('timeout', 'connection', 'done', 'pool')
10621062

1063+
timeout: typing.Optional[float]
1064+
connection: typing.Optional[
1065+
pool_connection_proxy.PoolConnectionProxy[_Record]
1066+
]
1067+
done: bool
1068+
pool: Pool[_Record]
1069+
10631070
def __init__(
10641071
self, pool: Pool[_Record], timeout: typing.Optional[float]
10651072
) -> None:
10661073
self.pool = pool
10671074
self.timeout = timeout
1068-
self.connection: typing.Optional[
1069-
pool_connection_proxy.PoolConnectionProxy[_Record]
1070-
] = None
1075+
self.connection = None
10711076
self.done = False
10721077

10731078
async def __aenter__(self) -> pool_connection_proxy.PoolConnectionProxy[

asyncpg/pool_connection_proxy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ class PoolConnectionProxy(connection._ConnectionProxy[_Record],
7979
__slots__ = ('_con', '_holder')
8080

8181
def __init__(self, holder: 'pool.PoolConnectionHolder[_Record]',
82-
con: connection.Connection):
82+
con: connection.Connection[_Record]):
8383
self._con = con
8484
self._holder = holder
8585
con._set_proxy(self)

asyncpg/serverversion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from .compat import TypedDict
1111
from .types import ServerVersion
1212

13-
version_regex = re.compile(
13+
_version_regex = re.compile(
1414
r"(Postgre[^\s]*)?\s*"
1515
r"(?P<major>[0-9]+)\.?"
1616
r"((?P<minor>[0-9]+)\.?)?"
@@ -29,7 +29,7 @@ class _VersionDict(TypedDict):
2929

3030

3131
def split_server_version_string(version_string: str) -> ServerVersion:
32-
version_match = version_regex.search(version_string)
32+
version_match = _version_regex.search(version_string)
3333

3434
if version_match is None:
3535
raise ValueError(

asyncpg/types.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def __gt__(self, other: '_RangeValue') -> bool:
6969

7070

7171
_V = typing.TypeVar('_V', bound=_RangeValue)
72+
_R = typing.TypeVar('_R', bound='Range[typing.Any]')
7273

7374

7475
class Range(typing.Generic[_V]):
@@ -208,4 +209,4 @@ def __repr__(self) -> str:
208209

209210
return '<Range {}>'.format(desc)
210211

211-
__str__ = __repr__
212+
__str__: typing.Callable[['Range[typing.Any]'], str] = __repr__

0 commit comments

Comments
 (0)