@@ -525,27 +525,20 @@ async def _get_statement(
525525 named : typing .Union [bool , str ] = False ,
526526 use_cache : bool = True ,
527527 ignore_custom_codec : bool = False ,
528- record_class : typing .Optional [typing .Type [_OtherRecord ]] = None
528+ record_class : typing .Optional [typing .Type [typing . Any ]] = None
529529 ) -> '_cprotocol.PreparedStatementState[typing.Any]' :
530- record_cls : typing .Optional [
531- typing .Union [
532- typing .Type [_Record ],
533- typing .Type [_OtherRecord ]
534- ]
535- ] = record_class
536-
537- if record_cls is None :
538- record_cls = self ._protocol .get_record_class ()
530+ if record_class is None :
531+ record_class = self ._protocol .get_record_class ()
539532 else :
540- _check_record_class (record_cls )
533+ _check_record_class (record_class )
541534
542535 if use_cache :
543- statement : typing . Optional [
544- '_cprotocol.PreparedStatementState[typing.Any]'
545- ] = self . _stmt_cache . get (( query , record_cls , ignore_custom_codec ) )
536+ cache_statement = self . _stmt_cache . get (
537+ ( query , record_class , ignore_custom_codec )
538+ )
546539
547- if statement is not None :
548- return statement
540+ if cache_statement is not None :
541+ return cache_statement
549542
550543 # Only use the cache when:
551544 # * `statement_cache_size` is greater than 0;
@@ -567,7 +560,7 @@ async def _get_statement(
567560 stmt_name ,
568561 query ,
569562 timeout ,
570- record_class = record_cls ,
563+ record_class = record_class ,
571564 ignore_custom_codec = ignore_custom_codec ,
572565 )
573566 need_reprepare = False
@@ -609,12 +602,12 @@ async def _get_statement(
609602 query ,
610603 timeout ,
611604 state = statement ,
612- record_class = record_cls ,
605+ record_class = record_class ,
613606 )
614607
615608 if use_cache :
616609 self ._stmt_cache .put (
617- (query , record_cls , ignore_custom_codec ), statement )
610+ (query , record_class , ignore_custom_codec ), statement )
618611
619612 # If we've just created a new statement object, check if there
620613 # are any statements for GC.
@@ -2595,10 +2588,10 @@ async def connect(dsn: typing.Optional[str] = None, *,
25952588 command_timeout : typing .Optional [float ] = None ,
25962589 ssl : typing .Optional [connect_utils .SSLType ] = None ,
25972590 direct_tls : bool = False ,
2598- connection_class : typing .Type [_Connection ] = Connection , # type: ignore[assignment] # noqa: E501
2599- record_class : typing .Type [_Record ] = protocol .Record , # type: ignore[assignment] # noqa: E501
2591+ connection_class : typing .Type [typing . Any ] = Connection ,
2592+ record_class : typing .Type [typing . Any ] = protocol .Record ,
26002593 server_settings : typing .Optional [
2601- typing .Dict [str , str ]] = None ) -> _Connection :
2594+ typing .Dict [str , str ]] = None ) -> Connection [ typing . Any ] :
26022595 r"""A coroutine to establish a connection to a PostgreSQL server.
26032596
26042597 The connection parameters may be specified either as a connection
@@ -2921,7 +2914,7 @@ async def connect(dsn: typing.Optional[str] = None, *,
29212914_StatementCacheKey = typing .Tuple [str , typing .Type [_Record ], bool ]
29222915
29232916
2924- class _StatementCacheEntry :
2917+ class _StatementCacheEntry ( typing . Generic [ _Record ]) :
29252918
29262919 __slots__ = ('_query' , '_statement' , '_cache' , '_cleanup_cb' )
29272920
@@ -2968,8 +2961,8 @@ def __init__(self, *, loop: asyncio.AbstractEventLoop,
29682961 # entries dict, whereas the unused one will group in the
29692962 # beginning of it.
29702963 self ._entries : collections .OrderedDict [
2971- _StatementCacheKey ['_cprotocol.Record' ],
2972- _StatementCacheEntry
2964+ _StatementCacheKey [typing . Any ],
2965+ _StatementCacheEntry [ typing . Any ]
29732966 ] = collections .OrderedDict ()
29742967
29752968 def __len__ (self ) -> int :
@@ -3004,7 +2997,9 @@ def get(
30042997 # The cache is disabled.
30052998 return None
30062999
3007- entry : typing .Optional [_StatementCacheEntry ] = self ._entries .get (query )
3000+ entry : typing .Optional [
3001+ _StatementCacheEntry [_Record ]
3002+ ] = self ._entries .get (query )
30083003 if entry is None :
30093004 return None
30103005
@@ -3058,7 +3053,9 @@ def clear(self) -> None:
30583053 self ._clear_entry_callback (entry )
30593054 self ._on_remove (entry ._statement )
30603055
3061- def _set_entry_timeout (self , entry : _StatementCacheEntry ) -> None :
3056+ def _set_entry_timeout (
3057+ self , entry : _StatementCacheEntry [typing .Any ]
3058+ ) -> None :
30623059 # Clear the existing timeout.
30633060 self ._clear_entry_callback (entry )
30643061
@@ -3071,19 +3068,23 @@ def _new_entry(
30713068 self ,
30723069 query : _StatementCacheKey [_Record ],
30733070 statement : '_cprotocol.PreparedStatementState[_Record]'
3074- ) -> _StatementCacheEntry :
3071+ ) -> _StatementCacheEntry [ _Record ] :
30753072 entry = _StatementCacheEntry (self , query , statement )
30763073 self ._set_entry_timeout (entry )
30773074 return entry
30783075
3079- def _on_entry_expired (self , entry : _StatementCacheEntry ) -> None :
3076+ def _on_entry_expired (
3077+ self , entry : _StatementCacheEntry [typing .Any ]
3078+ ) -> None :
30803079 # `call_later` callback, called when an entry stayed longer
30813080 # than `self._max_lifetime`.
30823081 if self ._entries .get (entry ._query ) is entry :
30833082 self ._entries .pop (entry ._query )
30843083 self ._on_remove (entry ._statement )
30853084
3086- def _clear_entry_callback (self , entry : _StatementCacheEntry ) -> None :
3085+ def _clear_entry_callback (
3086+ self , entry : _StatementCacheEntry [typing .Any ]
3087+ ) -> None :
30873088 if entry ._cleanup_cb is not None :
30883089 entry ._cleanup_cb .cancel ()
30893090
@@ -3213,22 +3214,26 @@ def _extract_stack(limit: int = 10) -> str:
32133214 """
32143215 frame = sys ._getframe ().f_back
32153216 try :
3216- stack = traceback .StackSummary .extract (
3217- traceback .walk_stack (frame ), lookup_lines = False )
3217+ stack : typing .List [
3218+ traceback .FrameSummary
3219+ ] = traceback .StackSummary .extract (
3220+ traceback .walk_stack (frame ), lookup_lines = False
3221+ )
32183222 finally :
32193223 del frame
32203224
32213225 apg_path = asyncpg .__path__ [0 ]
32223226 i = 0
32233227 while i < len (stack ) and stack [i ][0 ].startswith (apg_path ):
32243228 i += 1
3225- stack = stack [i :i + limit ] # type: ignore[assignment]
3229+
3230+ stack = stack [i :i + limit ]
32263231
32273232 stack .reverse ()
32283233 return '' .join (traceback .format_list (stack ))
32293234
32303235
3231- def _check_record_class (record_class : typing .Type [protocol . Record ]) -> None :
3236+ def _check_record_class (record_class : typing .Type [typing . Any ]) -> None :
32323237 if record_class is protocol .Record :
32333238 pass
32343239 elif (
0 commit comments