Skip to content

Commit e992d25

Browse files
committed
refactor[py]: Add type hints to remote_connection.py
- Added explicit type annotations to selenium.webdriver.remote.remote_connection.py - Improved code clarity and static type checking - Ensured compatibility with modern type checkers like Pyright and Mypy
1 parent 018d2c8 commit e992d25

File tree

1 file changed

+18
-11
lines changed

1 file changed

+18
-11
lines changed

py/selenium/webdriver/remote/remote_connection.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
import string
2121
import warnings
2222
from base64 import b64encode
23+
from typing import Any
2324
from typing import Optional
25+
from typing import TypeVar
2426
from urllib import parse
2527
from urllib.parse import urlparse
2628

@@ -35,6 +37,11 @@
3537

3638
LOGGER = logging.getLogger(__name__)
3739

40+
# TODO: Replace with 'Self' when Python 3.11+ is supported.
41+
# from typing import Self
42+
43+
RemoteConnectionType = TypeVar("RemoteConnectionType", bound="RemoteConnection")
44+
3845
remote_commands = {
3946
Command.NEW_SESSION: ("POST", "/session"),
4047
Command.QUIT: ("DELETE", "/session/$sessionId"),
@@ -158,7 +165,7 @@ class RemoteConnection:
158165
else socket.getdefaulttimeout()
159166
)
160167
_ca_certs = os.getenv("REQUESTS_CA_BUNDLE") if "REQUESTS_CA_BUNDLE" in os.environ else certifi.where()
161-
_client_config: ClientConfig = None
168+
_client_config: ClientConfig | None = None
162169

163170
system = platform.system().lower()
164171
if system == "darwin":
@@ -169,7 +176,7 @@ class RemoteConnection:
169176
user_agent = f"selenium/{__version__} (python {system})"
170177

171178
@classmethod
172-
def get_timeout(cls):
179+
def get_timeout(cls) -> float | int | None:
173180
""":Returns:
174181
175182
Timeout value in seconds for all http requests made to the
@@ -183,7 +190,7 @@ def get_timeout(cls):
183190
return cls._client_config.timeout
184191

185192
@classmethod
186-
def set_timeout(cls, timeout):
193+
def set_timeout(cls, timeout: int | float):
187194
"""Override the default timeout.
188195
189196
:Args:
@@ -207,7 +214,7 @@ def reset_timeout(cls):
207214
cls._client_config.reset_timeout()
208215

209216
@classmethod
210-
def get_certificate_bundle_path(cls):
217+
def get_certificate_bundle_path(cls) -> str:
211218
""":Returns:
212219
213220
Paths of the .pem encoded certificate to verify connection to
@@ -222,7 +229,7 @@ def get_certificate_bundle_path(cls):
222229
return cls._client_config.ca_certs
223230

224231
@classmethod
225-
def set_certificate_bundle_path(cls, path):
232+
def set_certificate_bundle_path(cls, path: str):
226233
"""Set the path to the certificate bundle to verify connection to
227234
command executor. Can also be set to None to disable certificate
228235
validation.
@@ -238,7 +245,7 @@ def set_certificate_bundle_path(cls, path):
238245
cls._client_config.ca_certs = path
239246

240247
@classmethod
241-
def get_remote_connection_headers(cls, parsed_url, keep_alive=False):
248+
def get_remote_connection_headers(cls, parsed_url: str, keep_alive: bool = False) -> dict[str, Any]:
242249
"""Get headers for remote request.
243250
244251
:Args:
@@ -309,7 +316,7 @@ def __init__(
309316
keep_alive: Optional[bool] = True,
310317
ignore_proxy: Optional[bool] = False,
311318
ignore_certificates: Optional[bool] = False,
312-
init_args_for_pool_manager: Optional[dict] = None,
319+
init_args_for_pool_manager: Optional[dict[Any, Any]] = None,
313320
client_config: Optional[ClientConfig] = None,
314321
):
315322
self._client_config = client_config or ClientConfig(
@@ -370,15 +377,15 @@ def __init__(
370377

371378
extra_commands = {}
372379

373-
def add_command(self, name, method, url):
380+
def add_command(self, name: str, method: str, url: str):
374381
"""Register a new command."""
375382
self._commands[name] = (method, url)
376383

377384
def get_command(self, name: str):
378385
"""Retrieve a command if it exists."""
379386
return self._commands.get(name)
380387

381-
def execute(self, command, params):
388+
def execute(self, command: str, params: dict[Any, Any]) -> dict[str, Any]:
382389
"""Send a command to the remote server.
383390
384391
Any path substitutions required for the URL mapped to the command should be
@@ -403,7 +410,7 @@ def execute(self, command, params):
403410
LOGGER.debug("%s %s %s", command_info[0], url, str(trimmed))
404411
return self._request(command_info[0], url, body=data)
405412

406-
def _request(self, method, url, body=None):
413+
def _request(self, method: str, url: str, body: str | None = None) -> dict[Any, Any]:
407414
"""Send an HTTP request to the remote server.
408415
409416
:Args:
@@ -470,7 +477,7 @@ def close(self):
470477
if hasattr(self, "_conn"):
471478
self._conn.clear()
472479

473-
def _trim_large_entries(self, input_dict, max_length=100):
480+
def _trim_large_entries(self, input_dict: dict[Any, Any], max_length: int = 100) -> dict[str, str]:
474481
"""Truncate string values in a dictionary if they exceed max_length.
475482
476483
:param dict: Dictionary with potentially large values

0 commit comments

Comments
 (0)