2020import string
2121import warnings
2222from base64 import b64encode
23+ from typing import Any
2324from typing import Optional
25+ from typing import TypeVar
2426from urllib import parse
2527from urllib .parse import urlparse
2628
3537
3638LOGGER = logging .getLogger (__name__ )
3739
40+ # TODO: Replace with 'Self' when Python 3.11+ is supported.
41+ # from typing import Self
42+
43+ RemoteConnectionType = TypeVar ("RemoteConnectionType" , bound = "RemoteConnection" )
44+
3845remote_commands = {
3946 Command .NEW_SESSION : ("POST" , "/session" ),
4047 Command .QUIT : ("DELETE" , "/session/$sessionId" ),
@@ -154,7 +161,7 @@ class RemoteConnection:
154161
155162 _timeout = socket .getdefaulttimeout ()
156163 _ca_certs = os .getenv ("REQUESTS_CA_BUNDLE" ) if "REQUESTS_CA_BUNDLE" in os .environ else certifi .where ()
157- _client_config : ClientConfig = None
164+ _client_config : ClientConfig | None = None
158165
159166 system = platform .system ().lower ()
160167 if system == "darwin" :
@@ -169,7 +176,7 @@ def client_config(self):
169176 return self ._client_config
170177
171178 @classmethod
172- def get_timeout (cls ):
179+ def get_timeout (cls ) -> float | int | None :
173180 """:Returns:
174181
175182 Timeout value in seconds for all http requests made to the
@@ -183,7 +190,7 @@ def get_timeout(cls):
183190 return cls ._client_config .timeout
184191
185192 @classmethod
186- def set_timeout (cls , timeout ):
193+ def set_timeout (cls , timeout : int | float ):
187194 """Override the default timeout.
188195
189196 :Args:
@@ -207,7 +214,7 @@ def reset_timeout(cls):
207214 cls ._client_config .reset_timeout ()
208215
209216 @classmethod
210- def get_certificate_bundle_path (cls ):
217+ def get_certificate_bundle_path (cls ) -> str :
211218 """:Returns:
212219
213220 Paths of the .pem encoded certificate to verify connection to
@@ -222,7 +229,7 @@ def get_certificate_bundle_path(cls):
222229 return cls ._client_config .ca_certs
223230
224231 @classmethod
225- def set_certificate_bundle_path (cls , path ):
232+ def set_certificate_bundle_path (cls , path : str ):
226233 """Set the path to the certificate bundle to verify connection to
227234 command executor. Can also be set to None to disable certificate
228235 validation.
@@ -238,7 +245,7 @@ def set_certificate_bundle_path(cls, path):
238245 cls ._client_config .ca_certs = path
239246
240247 @classmethod
241- def get_remote_connection_headers (cls , parsed_url , keep_alive = False ):
248+ def get_remote_connection_headers (cls , parsed_url : str , keep_alive : bool = False ) -> dict [ str , Any ] :
242249 """Get headers for remote request.
243250
244251 :Args:
@@ -309,7 +316,7 @@ def __init__(
309316 keep_alive : Optional [bool ] = True ,
310317 ignore_proxy : Optional [bool ] = False ,
311318 ignore_certificates : Optional [bool ] = False ,
312- init_args_for_pool_manager : Optional [dict ] = None ,
319+ init_args_for_pool_manager : Optional [dict [ Any , Any ] ] = None ,
313320 client_config : Optional [ClientConfig ] = None ,
314321 ):
315322 self ._client_config = client_config or ClientConfig (
@@ -370,15 +377,15 @@ def __init__(
370377
371378 extra_commands = {}
372379
373- def add_command (self , name , method , url ):
380+ def add_command (self , name : str , method : str , url : str ):
374381 """Register a new command."""
375382 self ._commands [name ] = (method , url )
376383
377384 def get_command (self , name : str ):
378385 """Retrieve a command if it exists."""
379386 return self ._commands .get (name )
380387
381- def execute (self , command , params ) :
388+ def execute (self , command : str , params : dict [ Any , Any ]) -> dict [ str , Any ] :
382389 """Send a command to the remote server.
383390
384391 Any path substitutions required for the URL mapped to the command should be
@@ -403,7 +410,7 @@ def execute(self, command, params):
403410 LOGGER .debug ("%s %s %s" , command_info [0 ], url , str (trimmed ))
404411 return self ._request (command_info [0 ], url , body = data )
405412
406- def _request (self , method , url , body = None ):
413+ def _request (self , method : str , url : str , body : str | None = None ) -> dict [ Any , Any ] :
407414 """Send an HTTP request to the remote server.
408415
409416 :Args:
@@ -470,7 +477,7 @@ def close(self):
470477 if hasattr (self , "_conn" ):
471478 self ._conn .clear ()
472479
473- def _trim_large_entries (self , input_dict , max_length = 100 ):
480+ def _trim_large_entries (self , input_dict : dict [ Any , Any ], max_length : int = 100 ) -> dict [ str , str ] :
474481 """Truncate string values in a dictionary if they exceed max_length.
475482
476483 :param dict: Dictionary with potentially large values
0 commit comments