Skip to content

Commit 4853b3b

Browse files
authored
Pyright 1.1.385 (patroni#3182)
Declaring variables with `Union` and using `isinstance()` hack doesn't work anymore. Therefore the code is updated to use `Any` for variable and `cast` function after firguring out the correct type in order to avoid getting errors about `Unknown` types.
1 parent ba970d8 commit 4853b3b

File tree

16 files changed

+132
-132
lines changed

16 files changed

+132
-132
lines changed

.github/workflows/tests.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ jobs:
186186

187187
- uses: jakebailey/pyright-action@v2
188188
with:
189-
version: 1.1.379
189+
version: 1.1.385
190190

191191
docs:
192192
runs-on: ubuntu-latest

patroni/api.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from ipaddress import ip_address, ip_network, IPv4Network, IPv6Network
2222
from socketserver import ThreadingMixIn
2323
from threading import Thread
24-
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, TYPE_CHECKING, Union
24+
from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Tuple, TYPE_CHECKING, Union
2525
from urllib.parse import parse_qs, urlparse
2626

2727
import dateutil.parser
@@ -71,8 +71,8 @@ def check_access(*args: Any, **kwargs: Any) -> Callable[..., Any]:
7171
"""
7272
allowlist_check_members = kwargs.get('allowlist_check_members', True)
7373

74-
def inner_decorator(func: Callable[..., None]) -> Callable[..., None]:
75-
def wrapper(self: 'RestApiHandler', *args: Any, **kwargs: Any) -> None:
74+
def inner_decorator(func: Callable[..., Any]) -> Callable[..., Any]:
75+
def wrapper(self: 'RestApiHandler', *args: Any, **kwargs: Any) -> Any:
7676
if self.server.check_access(self, allowlist_check_members=allowlist_check_members):
7777
return func(self, *args, **kwargs)
7878

@@ -698,9 +698,9 @@ def _read_json_content(self, body_is_optional: bool = False) -> Optional[Dict[An
698698
content_length = int(self.headers.get('content-length') or 0)
699699
if content_length == 0 and body_is_optional:
700700
return {}
701-
request: Union[Dict[str, Any], Any] = json.loads(self.rfile.read(content_length).decode('utf-8'))
701+
request = json.loads(self.rfile.read(content_length).decode('utf-8'))
702702
if isinstance(request, dict) and (request or body_is_optional):
703-
return request
703+
return cast(Dict[str, Any], request)
704704
except Exception:
705705
logger.exception('Bad request')
706706
self.send_error(400)
@@ -1723,12 +1723,11 @@ def get_certificate_serial_number(self) -> Optional[str]:
17231723
17241724
:returns: serial number of the certificate configured through ``restapi.certfile`` setting.
17251725
"""
1726-
if self.__ssl_options.get('certfile'):
1726+
certfile: Optional[str] = self.__ssl_options.get('certfile')
1727+
if certfile:
17271728
import ssl
17281729
try:
1729-
crt: Dict[str, Any] = ssl._ssl._test_decode_cert(self.__ssl_options['certfile']) # pyright: ignore
1730-
if TYPE_CHECKING: # pragma: no cover
1731-
assert isinstance(crt, dict)
1730+
crt = cast(Dict[str, Any], ssl._ssl._test_decode_cert(certfile)) # pyright: ignore
17321731
return crt.get('serialNumber')
17331732
except ssl.SSLError as e:
17341733
logger.error('Failed to get serial number from certificate %s: %r', self.__ssl_options['certfile'], e)

patroni/config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from collections import defaultdict
1010
from copy import deepcopy
11-
from typing import Any, Callable, Collection, Dict, List, Optional, TYPE_CHECKING, Union
11+
from typing import Any, Callable, cast, Collection, Dict, List, Optional, TYPE_CHECKING, Union
1212

1313
import yaml
1414

@@ -695,8 +695,8 @@ def _build_effective_configuration(self, dynamic_configuration: Dict[str, Any],
695695
config = self._safe_copy_dynamic_configuration(dynamic_configuration)
696696
for name, value in local_configuration.items():
697697
if name == 'citus': # remove invalid citus configuration
698-
if isinstance(value, dict) and isinstance(value.get('group'), int)\
699-
and isinstance(value.get('database'), str):
698+
if isinstance(value, dict) and isinstance(cast(Dict[str, Any], value).get('group'), int) \
699+
and isinstance(cast(Dict[str, Any], value).get('database'), str):
700700
config[name] = value
701701
elif name == 'postgresql':
702702
for name, value in (value or {}).items():

patroni/dcs/__init__.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
from copy import deepcopy
1111
from random import randint
1212
from threading import Event, Lock
13-
from typing import Any, Callable, Collection, Dict, Iterator, List, \
14-
NamedTuple, Optional, Set, Tuple, Type, TYPE_CHECKING, Union
13+
from typing import Any, Callable, cast, Collection, Dict, Iterator, \
14+
List, NamedTuple, Optional, Set, Tuple, Type, TYPE_CHECKING, Union
1515
from urllib.parse import parse_qsl, urlparse, urlunparse
1616

1717
import dateutil.parser
@@ -219,7 +219,7 @@ def conn_url(self) -> Optional[str]:
219219

220220
return None
221221

222-
def conn_kwargs(self, auth: Union[Any, Dict[str, Any], None] = None) -> Dict[str, Any]:
222+
def conn_kwargs(self, auth: Optional[Any] = None) -> Dict[str, Any]:
223223
"""Give keyword arguments used for PostgreSQL connection settings.
224224
225225
:param auth: Authentication properties - can be defined as anything supported by the ``psycopg2`` or
@@ -255,7 +255,7 @@ def conn_kwargs(self, auth: Union[Any, Dict[str, Any], None] = None) -> Dict[str
255255

256256
# apply any remaining authentication parameters
257257
if auth and isinstance(auth, dict):
258-
ret.update({k: v for k, v in auth.items() if v is not None})
258+
ret.update({k: v for k, v in cast(Dict[str, Any], auth).items() if v is not None})
259259
if 'username' in auth:
260260
ret['user'] = ret.pop('username')
261261
return ret
@@ -949,28 +949,28 @@ def get_clone_member(self, exclude_name: str) -> Union[Member, Leader, None]:
949949
return candidates[randint(0, len(candidates) - 1)] if candidates else self.leader
950950

951951
@staticmethod
952-
def is_physical_slot(value: Union[Any, Dict[str, Any]]) -> bool:
952+
def is_physical_slot(value: Any) -> bool:
953953
"""Check whether provided configuration is for permanent physical replication slot.
954954
955955
:param value: configuration of the permanent replication slot.
956956
957957
:returns: ``True`` if *value* is a physical replication slot, otherwise ``False``.
958958
"""
959959
return not value \
960-
or (isinstance(value, dict) and not Cluster.is_logical_slot(value)
961-
and value.get('type', 'physical') == 'physical')
960+
or (isinstance(value, dict) and not Cluster.is_logical_slot(cast(Dict[str, Any], value))
961+
and cast(Dict[str, Any], value).get('type', 'physical') == 'physical')
962962

963963
@staticmethod
964-
def is_logical_slot(value: Union[Any, Dict[str, Any]]) -> bool:
964+
def is_logical_slot(value: Any) -> bool:
965965
"""Check whether provided configuration is for permanent logical replication slot.
966966
967967
:param value: configuration of the permanent replication slot.
968968
969969
:returns: ``True`` if *value* is a logical replication slot, otherwise ``False``.
970970
"""
971971
return isinstance(value, dict) \
972-
and value.get('type', 'logical') == 'logical' \
973-
and bool(value.get('database') and value.get('plugin'))
972+
and cast(Dict[str, Any], value).get('type', 'logical') == 'logical' \
973+
and bool(cast(Dict[str, Any], value).get('database') and cast(Dict[str, Any], value).get('plugin'))
974974

975975
@property
976976
def __permanent_slots(self) -> Dict[str, Union[Dict[str, Any], Any]]:
@@ -992,7 +992,7 @@ def __permanent_slots(self) -> Dict[str, Union[Dict[str, Any], Any]]:
992992
value['lsn'] = lsn
993993
else:
994994
# Don't let anyone set 'lsn' in the global configuration :)
995-
value.pop('lsn', None)
995+
value.pop('lsn', None) # pyright: ignore [reportUnknownMemberType]
996996
return ret
997997

998998
@property
@@ -1066,8 +1066,9 @@ def _merge_permanent_slots(self, slots: Dict[str, Dict[str, Any]], permanent_slo
10661066
logger.error("Slot name may only contain lower case letters, numbers, and the underscore chars")
10671067
continue
10681068

1069-
value = deepcopy(value) if value else {'type': 'physical'}
1070-
if isinstance(value, dict):
1069+
tmp = deepcopy(value) if value else {'type': 'physical'}
1070+
if isinstance(tmp, dict):
1071+
value = cast(Dict[str, Any], tmp)
10711072
if 'type' not in value:
10721073
value['type'] = 'logical' if value.get('database') and value.get('plugin') else 'physical'
10731074

patroni/dcs/etcd3.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,7 @@ class InvalidAuthToken(Etcd3ClientError):
150150
errCodeToClientError = {getattr(s, 'code'): s for s in Etcd3ClientError.__subclasses__()}
151151

152152

153-
def _raise_for_data(data: Union[bytes, str, Dict[str, Union[Any, Dict[str, Any]]]],
154-
status_code: Optional[int] = None) -> Etcd3ClientError:
153+
def _raise_for_data(data: Union[bytes, str, Dict[str, Any]], status_code: Optional[int] = None) -> Etcd3ClientError:
155154
try:
156155
if TYPE_CHECKING: # pragma: no cover
157156
assert isinstance(data, dict)

patroni/dcs/exhibitor.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import random
44
import time
55

6-
from typing import Any, Callable, Dict, List, Union
6+
from typing import Any, Callable, cast, Dict, List, Union
77

88
from ..postgresql.mpp import AbstractMPP
99
from ..request import get as requests_get
@@ -41,16 +41,17 @@ def poll(self) -> bool:
4141

4242
if isinstance(json, dict) and 'servers' in json and 'port' in json:
4343
self._next_poll = time.time() + self._poll_interval
44-
servers: List[str] = json['servers']
45-
zookeeper_hosts = ','.join([h + ':' + str(json['port']) for h in sorted(servers)])
44+
servers: List[str] = cast(Dict[str, Any], json)['servers']
45+
port = str(cast(Dict[str, Any], json)['port'])
46+
zookeeper_hosts = ','.join([h + ':' + port for h in sorted(servers)])
4647
if self._zookeeper_hosts != zookeeper_hosts:
4748
logger.info('ZooKeeper connection string has changed: %s => %s', self._zookeeper_hosts, zookeeper_hosts)
4849
self._zookeeper_hosts = zookeeper_hosts
4950
self._exhibitors = json['servers']
5051
return True
5152
return False
5253

53-
def _query_exhibitors(self, exhibitors: List[str]) -> Union[Dict[str, Any], Any]:
54+
def _query_exhibitors(self, exhibitors: List[str]) -> Any:
5455
random.shuffle(exhibitors)
5556
for host in exhibitors:
5657
try:

patroni/dcs/kubernetes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -648,7 +648,7 @@ def get(self, name: str) -> Optional[K8sObject]:
648648
with self._object_cache_lock:
649649
return self._object_cache.get(name)
650650

651-
def _process_event(self, event: Dict[str, Union[Any, Dict[str, Union[Any, Dict[str, Any]]]]]) -> None:
651+
def _process_event(self, event: Dict[str, Any]) -> None:
652652
ev_type = event['type']
653653
obj = event['object']
654654
name = obj['metadata']['name']

patroni/dcs/raft.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ def callback(*args: Any) -> None:
255255
self.__limb.pop(key)
256256
self._expire(key, value, callback=callback)
257257

258-
def get(self, key: str, recursive: bool = False) -> Union[None, Dict[str, Any], Dict[str, Dict[str, Any]]]:
258+
def get(self, key: str, recursive: bool = False) -> Optional[Dict[str, Any]]:
259259
if not recursive:
260260
return self.__data.get(key)
261261
return {k: v for k, v in self.__data.items() if k.startswith(key)}

patroni/dcs/zookeeper.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import socket
55
import time
66

7-
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
7+
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
88

99
from kazoo.client import KazooClient, KazooRetry, KazooState
1010
from kazoo.exceptions import ConnectionClosedError, NodeExistsError, NoNodeError, SessionExpiredError
@@ -180,7 +180,8 @@ def ttl(self) -> int:
180180
return int(self._client._session_timeout / 1000.0)
181181

182182
def set_retry_timeout(self, retry_timeout: int) -> None:
183-
retry = self._client.retry if isinstance(self._client.retry, KazooRetry) else self._client._retry
183+
old_kazoo = isinstance(self._client.retry, KazooRetry) # pyright: ignore [reportUnnecessaryIsInstance]
184+
retry = cast(KazooRetry, self._client.retry) if old_kazoo else self._client._retry
184185
retry.deadline = retry_timeout
185186

186187
def get_node(

patroni/global_config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import types
99

1010
from copy import deepcopy
11-
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union
11+
from typing import Any, cast, Dict, List, Optional, TYPE_CHECKING
1212

1313
from .collections import EMPTY_DICT
1414
from .utils import parse_bool, parse_int
@@ -121,7 +121,7 @@ def is_synchronous_mode_strict(self) -> bool:
121121
"""``True`` if at least one synchronous node is required."""
122122
return self.check_mode('synchronous_mode_strict')
123123

124-
def get_standby_cluster_config(self) -> Union[Dict[str, Any], Any]:
124+
def get_standby_cluster_config(self) -> Any:
125125
"""Get ``standby_cluster`` configuration.
126126
127127
:returns: a copy of ``standby_cluster`` configuration.
@@ -133,7 +133,7 @@ def is_standby_cluster(self) -> bool:
133133
"""``True`` if global configuration has a valid ``standby_cluster`` section."""
134134
config = self.get_standby_cluster_config()
135135
return isinstance(config, dict) and\
136-
bool(config.get('host') or config.get('port') or config.get('restore_command'))
136+
any(cast(Dict[str, Any], config).get(p) for p in ('host', 'port', 'restore_command'))
137137

138138
def get_int(self, name: str, default: int = 0, base_unit: Optional[str] = None) -> int:
139139
"""Gets current value of *name* from the global configuration and try to return it as :class:`int`.

0 commit comments

Comments
 (0)