Skip to content

Commit 9718afc

Browse files
committed
Remove some ignore comments and clean up typings
1 parent 465e357 commit 9718afc

File tree

4 files changed

+74
-49
lines changed

4 files changed

+74
-49
lines changed

asyncpg/connect_utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,13 @@
2929
from . import exceptions
3030
from . import protocol
3131

32-
_Connection = typing.TypeVar('_Connection')
32+
if typing.TYPE_CHECKING:
33+
from . import connection
34+
35+
_Connection = typing.TypeVar(
36+
'_Connection',
37+
bound='connection.Connection[typing.Any]'
38+
)
3339
_Protocol = typing.TypeVar('_Protocol', bound='protocol.Protocol[typing.Any]')
3440
_Record = typing.TypeVar('_Record', bound=protocol.Record)
3541
_SSLMode = typing.TypeVar('_SSLMode', bound='SSLMode')
@@ -979,7 +985,7 @@ async def __connect_addr(
979985
tr.close()
980986
raise
981987

982-
con = connection_class( # type: ignore[call-arg]
988+
con = connection_class(
983989
pr, tr, loop, addr, config, params_input
984990
)
985991
pr.set_connection(con)

asyncpg/connection.py

Lines changed: 39 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -525,27 +525,20 @@ async def _get_statement(
525525
named: typing.Union[bool, str] = False,
526526
use_cache: bool = True,
527527
ignore_custom_codec: bool = False,
528-
record_class: typing.Optional[typing.Type[_OtherRecord]] = None
528+
record_class: typing.Optional[typing.Type[typing.Any]] = None
529529
) -> '_cprotocol.PreparedStatementState[typing.Any]':
530-
record_cls: typing.Optional[
531-
typing.Union[
532-
typing.Type[_Record],
533-
typing.Type[_OtherRecord]
534-
]
535-
] = record_class
536-
537-
if record_cls is None:
538-
record_cls = self._protocol.get_record_class()
530+
if record_class is None:
531+
record_class = self._protocol.get_record_class()
539532
else:
540-
_check_record_class(record_cls)
533+
_check_record_class(record_class)
541534

542535
if use_cache:
543-
statement: typing.Optional[
544-
'_cprotocol.PreparedStatementState[typing.Any]'
545-
] = self._stmt_cache.get((query, record_cls, ignore_custom_codec))
536+
cache_statement = self._stmt_cache.get(
537+
(query, record_class, ignore_custom_codec)
538+
)
546539

547-
if statement is not None:
548-
return statement
540+
if cache_statement is not None:
541+
return cache_statement
549542

550543
# Only use the cache when:
551544
# * `statement_cache_size` is greater than 0;
@@ -567,7 +560,7 @@ async def _get_statement(
567560
stmt_name,
568561
query,
569562
timeout,
570-
record_class=record_cls,
563+
record_class=record_class,
571564
ignore_custom_codec=ignore_custom_codec,
572565
)
573566
need_reprepare = False
@@ -609,12 +602,12 @@ async def _get_statement(
609602
query,
610603
timeout,
611604
state=statement,
612-
record_class=record_cls,
605+
record_class=record_class,
613606
)
614607

615608
if use_cache:
616609
self._stmt_cache.put(
617-
(query, record_cls, ignore_custom_codec), statement)
610+
(query, record_class, ignore_custom_codec), statement)
618611

619612
# If we've just created a new statement object, check if there
620613
# are any statements for GC.
@@ -2595,10 +2588,10 @@ async def connect(dsn: typing.Optional[str] = None, *,
25952588
command_timeout: typing.Optional[float] = None,
25962589
ssl: typing.Optional[connect_utils.SSLType] = None,
25972590
direct_tls: bool = False,
2598-
connection_class: typing.Type[_Connection] = Connection, # type: ignore[assignment] # noqa: E501
2599-
record_class: typing.Type[_Record] = protocol.Record, # type: ignore[assignment] # noqa: E501
2591+
connection_class: typing.Type[typing.Any] = Connection,
2592+
record_class: typing.Type[typing.Any] = protocol.Record,
26002593
server_settings: typing.Optional[
2601-
typing.Dict[str, str]] = None) -> _Connection:
2594+
typing.Dict[str, str]] = None) -> Connection[typing.Any]:
26022595
r"""A coroutine to establish a connection to a PostgreSQL server.
26032596
26042597
The connection parameters may be specified either as a connection
@@ -2921,7 +2914,7 @@ async def connect(dsn: typing.Optional[str] = None, *,
29212914
_StatementCacheKey = typing.Tuple[str, typing.Type[_Record], bool]
29222915

29232916

2924-
class _StatementCacheEntry:
2917+
class _StatementCacheEntry(typing.Generic[_Record]):
29252918

29262919
__slots__ = ('_query', '_statement', '_cache', '_cleanup_cb')
29272920

@@ -2968,8 +2961,8 @@ def __init__(self, *, loop: asyncio.AbstractEventLoop,
29682961
# entries dict, whereas the unused one will group in the
29692962
# beginning of it.
29702963
self._entries: collections.OrderedDict[
2971-
_StatementCacheKey['_cprotocol.Record'],
2972-
_StatementCacheEntry
2964+
_StatementCacheKey[typing.Any],
2965+
_StatementCacheEntry[typing.Any]
29732966
] = collections.OrderedDict()
29742967

29752968
def __len__(self) -> int:
@@ -3004,7 +2997,9 @@ def get(
30042997
# The cache is disabled.
30052998
return None
30062999

3007-
entry: typing.Optional[_StatementCacheEntry] = self._entries.get(query)
3000+
entry: typing.Optional[
3001+
_StatementCacheEntry[_Record]
3002+
] = self._entries.get(query)
30083003
if entry is None:
30093004
return None
30103005

@@ -3058,7 +3053,9 @@ def clear(self) -> None:
30583053
self._clear_entry_callback(entry)
30593054
self._on_remove(entry._statement)
30603055

3061-
def _set_entry_timeout(self, entry: _StatementCacheEntry) -> None:
3056+
def _set_entry_timeout(
3057+
self, entry: _StatementCacheEntry[typing.Any]
3058+
) -> None:
30623059
# Clear the existing timeout.
30633060
self._clear_entry_callback(entry)
30643061

@@ -3071,19 +3068,23 @@ def _new_entry(
30713068
self,
30723069
query: _StatementCacheKey[_Record],
30733070
statement: '_cprotocol.PreparedStatementState[_Record]'
3074-
) -> _StatementCacheEntry:
3071+
) -> _StatementCacheEntry[_Record]:
30753072
entry = _StatementCacheEntry(self, query, statement)
30763073
self._set_entry_timeout(entry)
30773074
return entry
30783075

3079-
def _on_entry_expired(self, entry: _StatementCacheEntry) -> None:
3076+
def _on_entry_expired(
3077+
self, entry: _StatementCacheEntry[typing.Any]
3078+
) -> None:
30803079
# `call_later` callback, called when an entry stayed longer
30813080
# than `self._max_lifetime`.
30823081
if self._entries.get(entry._query) is entry:
30833082
self._entries.pop(entry._query)
30843083
self._on_remove(entry._statement)
30853084

3086-
def _clear_entry_callback(self, entry: _StatementCacheEntry) -> None:
3085+
def _clear_entry_callback(
3086+
self, entry: _StatementCacheEntry[typing.Any]
3087+
) -> None:
30873088
if entry._cleanup_cb is not None:
30883089
entry._cleanup_cb.cancel()
30893090

@@ -3213,22 +3214,26 @@ def _extract_stack(limit: int = 10) -> str:
32133214
"""
32143215
frame = sys._getframe().f_back
32153216
try:
3216-
stack = traceback.StackSummary.extract(
3217-
traceback.walk_stack(frame), lookup_lines=False)
3217+
stack: typing.List[
3218+
traceback.FrameSummary
3219+
] = traceback.StackSummary.extract(
3220+
traceback.walk_stack(frame), lookup_lines=False
3221+
)
32183222
finally:
32193223
del frame
32203224

32213225
apg_path = asyncpg.__path__[0]
32223226
i = 0
32233227
while i < len(stack) and stack[i][0].startswith(apg_path):
32243228
i += 1
3225-
stack = stack[i:i + limit] # type: ignore[assignment]
3229+
3230+
stack = stack[i:i + limit]
32263231

32273232
stack.reverse()
32283233
return ''.join(traceback.format_list(stack))
32293234

32303235

3231-
def _check_record_class(record_class: typing.Type[protocol.Record]) -> None:
3236+
def _check_record_class(record_class: typing.Type[typing.Any]) -> None:
32323237
if record_class is protocol.Record:
32333238
pass
32343239
elif (

asyncpg/protocol/protocol.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ class BaseProtocol(CoreProtocol, Generic[_Record]):
220220
state: Optional[_PreparedStatementState] = ...,
221221
ignore_custom_codec: bool = ...,
222222
record_class: Optional[type[_OtherRecord]]
223-
) -> Union[_PreparedStatementState, type[_OtherRecord]]: ...
223+
) -> Union[_PreparedStatementState, PreparedStatementState[_OtherRecord]]: ...
224224
async def query(self, *args: Any, **kwargs: Any) -> str: ...
225225
def resume_writing(self, *args: Any, **kwargs: Any) -> Any: ...
226226
def __reduce__(self) -> Any: ...

asyncpg/types.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
Box, Line, LineSegment, Circle,
1313
)
1414

15+
from . import compat
16+
1517

1618
__all__ = (
1719
'Type', 'Attribute', 'Range', 'BitString', 'Point', 'Path', 'Polygon',
@@ -54,16 +56,28 @@ class ServerVersion(typing.NamedTuple):
5456

5557
ServerVersion.__doc__ = 'PostgreSQL server version tuple.'
5658

57-
T = typing.TypeVar('T')
59+
60+
class _RangeValue(compat.Protocol):
61+
def __eq__(self, other: object) -> bool:
62+
...
63+
64+
def __lt__(self, other: '_RangeValue') -> bool:
65+
...
66+
67+
def __gt__(self, other: '_RangeValue') -> bool:
68+
...
69+
70+
71+
_V = typing.TypeVar('_V', bound=_RangeValue)
5872

5973

60-
class Range(typing.Generic[T]):
74+
class Range(typing.Generic[_V]):
6175
"""Immutable representation of PostgreSQL `range` type."""
6276

6377
__slots__ = '_lower', '_upper', '_lower_inc', '_upper_inc', '_empty'
6478

65-
def __init__(self, lower: typing.Optional[T] = None,
66-
upper: typing.Optional[T] = None, *,
79+
def __init__(self, lower: typing.Optional[_V] = None,
80+
upper: typing.Optional[_V] = None, *,
6781
lower_inc: bool = True,
6882
upper_inc: bool = False,
6983
empty: bool = False) -> None:
@@ -78,7 +92,7 @@ def __init__(self, lower: typing.Optional[T] = None,
7892
self._upper_inc = upper is not None and upper_inc
7993

8094
@property
81-
def lower(self) -> typing.Optional[T]:
95+
def lower(self) -> typing.Optional[_V]:
8296
return self._lower
8397

8498
@property
@@ -90,7 +104,7 @@ def lower_inf(self) -> bool:
90104
return self._lower is None and not self._empty
91105

92106
@property
93-
def upper(self) -> typing.Optional[T]:
107+
def upper(self) -> typing.Optional[_V]:
94108
return self._upper
95109

96110
@property
@@ -105,37 +119,37 @@ def upper_inf(self) -> bool:
105119
def isempty(self) -> bool:
106120
return self._empty
107121

108-
def _issubset_lower(self, other: 'Range[T]') -> bool:
122+
def _issubset_lower(self, other: 'Range[_V]') -> bool:
109123
if other._lower is None:
110124
return True
111125
if self._lower is None:
112126
return False
113127

114-
return self._lower > other._lower or ( # type: ignore[operator]
128+
return self._lower > other._lower or (
115129
self._lower == other._lower
116130
and (other._lower_inc or not self._lower_inc)
117131
)
118132

119-
def _issubset_upper(self, other: 'Range[T]') -> bool:
133+
def _issubset_upper(self, other: 'Range[_V]') -> bool:
120134
if other._upper is None:
121135
return True
122136
if self._upper is None:
123137
return False
124138

125-
return self._upper < other._upper or ( # type: ignore[operator]
139+
return self._upper < other._upper or (
126140
self._upper == other._upper
127141
and (other._upper_inc or not self._upper_inc)
128142
)
129143

130-
def issubset(self, other: 'Range[T]') -> bool:
144+
def issubset(self, other: 'Range[_V]') -> bool:
131145
if self._empty:
132146
return True
133147
if other._empty:
134148
return False
135149

136150
return self._issubset_lower(other) and self._issubset_upper(other)
137151

138-
def issuperset(self, other: 'Range[T]') -> bool:
152+
def issuperset(self, other: 'Range[_V]') -> bool:
139153
return other.issubset(self)
140154

141155
def __bool__(self) -> bool:

0 commit comments

Comments
 (0)