@@ -525,27 +525,20 @@ async def _get_statement(
525
525
named : typing .Union [bool , str ] = False ,
526
526
use_cache : bool = True ,
527
527
ignore_custom_codec : bool = False ,
528
- record_class : typing .Optional [typing .Type [_OtherRecord ]] = None
528
+ record_class : typing .Optional [typing .Type [typing . Any ]] = None
529
529
) -> '_cprotocol.PreparedStatementState[typing.Any]' :
530
- record_cls : typing .Optional [
531
- typing .Union [
532
- typing .Type [_Record ],
533
- typing .Type [_OtherRecord ]
534
- ]
535
- ] = record_class
536
-
537
- if record_cls is None :
538
- record_cls = self ._protocol .get_record_class ()
530
+ if record_class is None :
531
+ record_class = self ._protocol .get_record_class ()
539
532
else :
540
- _check_record_class (record_cls )
533
+ _check_record_class (record_class )
541
534
542
535
if use_cache :
543
- statement : typing . Optional [
544
- '_cprotocol.PreparedStatementState[typing.Any]'
545
- ] = self . _stmt_cache . get (( query , record_cls , ignore_custom_codec ) )
536
+ cache_statement = self . _stmt_cache . get (
537
+ ( query , record_class , ignore_custom_codec )
538
+ )
546
539
547
- if statement is not None :
548
- return statement
540
+ if cache_statement is not None :
541
+ return cache_statement
549
542
550
543
# Only use the cache when:
551
544
# * `statement_cache_size` is greater than 0;
@@ -567,7 +560,7 @@ async def _get_statement(
567
560
stmt_name ,
568
561
query ,
569
562
timeout ,
570
- record_class = record_cls ,
563
+ record_class = record_class ,
571
564
ignore_custom_codec = ignore_custom_codec ,
572
565
)
573
566
need_reprepare = False
@@ -609,12 +602,12 @@ async def _get_statement(
609
602
query ,
610
603
timeout ,
611
604
state = statement ,
612
- record_class = record_cls ,
605
+ record_class = record_class ,
613
606
)
614
607
615
608
if use_cache :
616
609
self ._stmt_cache .put (
617
- (query , record_cls , ignore_custom_codec ), statement )
610
+ (query , record_class , ignore_custom_codec ), statement )
618
611
619
612
# If we've just created a new statement object, check if there
620
613
# are any statements for GC.
@@ -2595,10 +2588,10 @@ async def connect(dsn: typing.Optional[str] = None, *,
2595
2588
command_timeout : typing .Optional [float ] = None ,
2596
2589
ssl : typing .Optional [connect_utils .SSLType ] = None ,
2597
2590
direct_tls : bool = False ,
2598
- connection_class : typing .Type [_Connection ] = Connection , # type: ignore[assignment] # noqa: E501
2599
- record_class : typing .Type [_Record ] = protocol .Record , # type: ignore[assignment] # noqa: E501
2591
+ connection_class : typing .Type [typing . Any ] = Connection ,
2592
+ record_class : typing .Type [typing . Any ] = protocol .Record ,
2600
2593
server_settings : typing .Optional [
2601
- typing .Dict [str , str ]] = None ) -> _Connection :
2594
+ typing .Dict [str , str ]] = None ) -> Connection [ typing . Any ] :
2602
2595
r"""A coroutine to establish a connection to a PostgreSQL server.
2603
2596
2604
2597
The connection parameters may be specified either as a connection
@@ -2921,7 +2914,7 @@ async def connect(dsn: typing.Optional[str] = None, *,
2921
2914
_StatementCacheKey = typing .Tuple [str , typing .Type [_Record ], bool ]
2922
2915
2923
2916
2924
- class _StatementCacheEntry :
2917
+ class _StatementCacheEntry ( typing . Generic [ _Record ]) :
2925
2918
2926
2919
__slots__ = ('_query' , '_statement' , '_cache' , '_cleanup_cb' )
2927
2920
@@ -2968,8 +2961,8 @@ def __init__(self, *, loop: asyncio.AbstractEventLoop,
2968
2961
# entries dict, whereas the unused one will group in the
2969
2962
# beginning of it.
2970
2963
self ._entries : collections .OrderedDict [
2971
- _StatementCacheKey ['_cprotocol.Record' ],
2972
- _StatementCacheEntry
2964
+ _StatementCacheKey [typing . Any ],
2965
+ _StatementCacheEntry [ typing . Any ]
2973
2966
] = collections .OrderedDict ()
2974
2967
2975
2968
def __len__ (self ) -> int :
@@ -3004,7 +2997,9 @@ def get(
3004
2997
# The cache is disabled.
3005
2998
return None
3006
2999
3007
- entry : typing .Optional [_StatementCacheEntry ] = self ._entries .get (query )
3000
+ entry : typing .Optional [
3001
+ _StatementCacheEntry [_Record ]
3002
+ ] = self ._entries .get (query )
3008
3003
if entry is None :
3009
3004
return None
3010
3005
@@ -3058,7 +3053,9 @@ def clear(self) -> None:
3058
3053
self ._clear_entry_callback (entry )
3059
3054
self ._on_remove (entry ._statement )
3060
3055
3061
- def _set_entry_timeout (self , entry : _StatementCacheEntry ) -> None :
3056
+ def _set_entry_timeout (
3057
+ self , entry : _StatementCacheEntry [typing .Any ]
3058
+ ) -> None :
3062
3059
# Clear the existing timeout.
3063
3060
self ._clear_entry_callback (entry )
3064
3061
@@ -3071,19 +3068,23 @@ def _new_entry(
3071
3068
self ,
3072
3069
query : _StatementCacheKey [_Record ],
3073
3070
statement : '_cprotocol.PreparedStatementState[_Record]'
3074
- ) -> _StatementCacheEntry :
3071
+ ) -> _StatementCacheEntry [ _Record ] :
3075
3072
entry = _StatementCacheEntry (self , query , statement )
3076
3073
self ._set_entry_timeout (entry )
3077
3074
return entry
3078
3075
3079
- def _on_entry_expired (self , entry : _StatementCacheEntry ) -> None :
3076
+ def _on_entry_expired (
3077
+ self , entry : _StatementCacheEntry [typing .Any ]
3078
+ ) -> None :
3080
3079
# `call_later` callback, called when an entry stayed longer
3081
3080
# than `self._max_lifetime`.
3082
3081
if self ._entries .get (entry ._query ) is entry :
3083
3082
self ._entries .pop (entry ._query )
3084
3083
self ._on_remove (entry ._statement )
3085
3084
3086
- def _clear_entry_callback (self , entry : _StatementCacheEntry ) -> None :
3085
+ def _clear_entry_callback (
3086
+ self , entry : _StatementCacheEntry [typing .Any ]
3087
+ ) -> None :
3087
3088
if entry ._cleanup_cb is not None :
3088
3089
entry ._cleanup_cb .cancel ()
3089
3090
@@ -3213,22 +3214,26 @@ def _extract_stack(limit: int = 10) -> str:
3213
3214
"""
3214
3215
frame = sys ._getframe ().f_back
3215
3216
try :
3216
- stack = traceback .StackSummary .extract (
3217
- traceback .walk_stack (frame ), lookup_lines = False )
3217
+ stack : typing .List [
3218
+ traceback .FrameSummary
3219
+ ] = traceback .StackSummary .extract (
3220
+ traceback .walk_stack (frame ), lookup_lines = False
3221
+ )
3218
3222
finally :
3219
3223
del frame
3220
3224
3221
3225
apg_path = asyncpg .__path__ [0 ]
3222
3226
i = 0
3223
3227
while i < len (stack ) and stack [i ][0 ].startswith (apg_path ):
3224
3228
i += 1
3225
- stack = stack [i :i + limit ] # type: ignore[assignment]
3229
+
3230
+ stack = stack [i :i + limit ]
3226
3231
3227
3232
stack .reverse ()
3228
3233
return '' .join (traceback .format_list (stack ))
3229
3234
3230
3235
3231
- def _check_record_class (record_class : typing .Type [protocol . Record ]) -> None :
3236
+ def _check_record_class (record_class : typing .Type [typing . Any ]) -> None :
3232
3237
if record_class is protocol .Record :
3233
3238
pass
3234
3239
elif (
0 commit comments