3434
3535__all__ = ["Connection" ]
3636
37- logging .getLogger (__name__ )
37+ log = logging .getLogger (__name__ )
3838
3939# guard for when readthedocs is building documentation or travis
4040# is running CI build
@@ -121,21 +121,22 @@ class Connection(metaclass=_ConnectionMeta):
121121
122122 @overload
123123 def __new__ (cls , ssh_server : str , local : Literal [False ], quiet : bool ,
124- thread_safe : bool ) -> SSHConnection :
124+ thread_safe : bool , allow_agent : bool ) -> SSHConnection :
125125 ...
126126
127127 @overload
128128 def __new__ (cls , ssh_server : str , local : Literal [True ], quiet : bool ,
129- thread_safe : bool ) -> LocalConnection :
129+ thread_safe : bool , allow_agent : bool ) -> LocalConnection :
130130 ...
131131
132132 @overload
133133 def __new__ (cls , ssh_server : str , local : bool , quiet : bool ,
134- thread_safe : bool ) -> Union [SSHConnection , LocalConnection ]:
134+ thread_safe : bool , allow_agent : bool
135+ ) -> Union [SSHConnection , LocalConnection ]:
135136 ...
136137
137138 def __new__ (cls , ssh_server : str , local : bool = False , quiet : bool = False ,
138- thread_safe : bool = False ):
139+ thread_safe : bool = False , allow_agent : bool = True ):
139140 """Get Connection based on one of names defined in .ssh/config file.
140141
141142 If name of local PC is passed initilize LocalConnection
@@ -152,11 +153,14 @@ def __new__(cls, ssh_server: str, local: bool = False, quiet: bool = False,
152153 make connection object thread safe so it can be safely accessed
153154 from any number of threads, it is disabled by default to avoid
154155 performance penalty of threading locks
156+ allow_agent: bool
157+ allows use of ssh agent for connection authentication, when this is
158+ `True` key for the host does not have to be available.
155159
156160 Raises
157161 ------
158162 KeyError
159- if server name is not in config file
163+ if server name is not in config file and allow agent is false
160164
161165 Returns
162166 -------
@@ -173,14 +177,36 @@ def __new__(cls, ssh_server: str, local: bool = False, quiet: bool = False,
173177 raise KeyError (f"couldn't find login credentials for { ssh_server } :"
174178 f" { e } " )
175179 else :
180+ # get username and address
176181 try :
177- return cls .open (credentials ["user" ], credentials ["hostname" ],
178- credentials ["identityfile" ][0 ],
179- server_name = ssh_server , quiet = quiet ,
180- thread_safe = thread_safe )
182+ user = credentials ["user" ]
183+ hostname = credentials ["hostname" ]
181184 except KeyError as e :
182- raise KeyError (f"{ RED } missing key in config dictionary for "
183- f"{ ssh_server } : { R } { e } " )
185+ raise KeyError (
186+ "Cannot find username or hostname for specified host"
187+ )
188+
189+ # get key or use agent
190+ if allow_agent :
191+ log .info (f"no private key supplied for { hostname } , will try "
192+ f"to authenticate through ssh-agent" )
193+ pkey_file = None
194+ else :
195+ log .info (f"private key found for host: { hostname } " )
196+ try :
197+ pkey_file = credentials ["identityfile" ][0 ]
198+ except (KeyError , IndexError ) as e :
199+ raise KeyError (f"No private key found for specified host" )
200+
201+ return cls .open (
202+ user ,
203+ hostname ,
204+ ssh_key_file = pkey_file ,
205+ allow_agent = allow_agent ,
206+ server_name = ssh_server ,
207+ quiet = quiet ,
208+ thread_safe = thread_safe
209+ )
184210
185211 @classmethod
186212 def get_available_hosts (cls ) -> List [str ]:
@@ -212,7 +238,8 @@ def get(cls, *args, **kwargs):
212238 get_connection = get
213239
214240 @classmethod
215- def add_hosts (cls , hosts : Union ["_HOSTS" , List ["_HOSTS" ]]):
241+ def add_hosts (cls , hosts : Union ["_HOSTS" , List ["_HOSTS" ]],
242+ allow_agent : Union [bool , List [bool ]]):
216243 """Add or override availbale host read fron ssh config file.
217244
218245 You can use supplied config parser to parse some externaf ssh config
@@ -223,15 +250,22 @@ def add_hosts(cls, hosts: Union["_HOSTS", List["_HOSTS"]]):
223250 hosts : Union[_HOSTS, List[_HOSTS]]
224251 dictionary or a list of dictionaries containing keys: `user`,
225252 `hostname` and `identityfile`
253+ allow_agent: Union[bool, List[bool]]
254+ bool or a list of bools with corresponding length to list of hosts.
255+ if only one bool is passed in, it will be used for all host entries
226256
227257 See also
228258 --------
229259 :func:ssh_utilities.config_parser
230260 """
231261 if not isinstance (hosts , list ):
232262 hosts = [hosts ]
263+ if not isinstance (allow_agent , list ):
264+ allow_agent = [allow_agent ] * len (hosts )
233265
234- for h in hosts :
266+ for h , a in zip (hosts , allow_agent ):
267+ if a :
268+ h ["identityfile" ][0 ] = None
235269 if not isinstance (h ["identityfile" ], list ):
236270 h ["identityfile" ] = [h ["identityfile" ]]
237271 h ["identityfile" ][0 ] = os .path .abspath (
@@ -300,7 +334,7 @@ def open(ssh_username: str, ssh_server: None = None,
300334 ssh_password : Optional [str ] = None ,
301335 server_name : Optional [str ] = None , quiet : bool = False ,
302336 thread_safe : bool = False ,
303- ssh_allow_agent : bool = False ) -> LocalConnection :
337+ allow_agent : bool = False ) -> LocalConnection :
304338 ...
305339
306340 @overload
@@ -310,7 +344,7 @@ def open(ssh_username: str, ssh_server: str,
310344 ssh_password : Optional [str ] = None ,
311345 server_name : Optional [str ] = None , quiet : bool = False ,
312346 thread_safe : bool = False ,
313- ssh_allow_agent : bool = False ) -> SSHConnection :
347+ allow_agent : bool = False ) -> SSHConnection :
314348 ...
315349
316350 @staticmethod
@@ -319,7 +353,7 @@ def open(ssh_username: str, ssh_server: Optional[str] = "",
319353 ssh_password : Optional [str ] = None ,
320354 server_name : Optional [str ] = None , quiet : bool = False ,
321355 thread_safe : bool = False ,
322- ssh_allow_agent : bool = False ):
356+ allow_agent : bool = False ):
323357 """Initialize SSH or local connection.
324358
325359 Local connection is only a wrapper around os and shutil module methods
@@ -346,7 +380,7 @@ def open(ssh_username: str, ssh_server: Optional[str] = "",
346380 make connection object thread safe so it can be safely accessed
347381 from any number of threads, it is disabled by default to avoid
348382 performance penalty of threading locks
349- ssh_allow_agent : bool
383+ allow_agent : bool
350384 allow the use of the ssh-agent to connect. Will disable ssh_key_file.
351385
352386 Warnings
@@ -355,27 +389,29 @@ def open(ssh_username: str, ssh_server: Optional[str] = "",
355389 risk!
356390 """
357391 if not ssh_server :
358- return LocalConnection (ssh_server , ssh_username ,
359- pkey_file = ssh_key_file ,
360- server_name = server_name , quiet = quiet )
361- else :
362- if ssh_allow_agent :
363- c = SSHConnection (ssh_server , ssh_username ,
364- allow_agent = ssh_allow_agent , line_rewrite = True ,
365- server_name = server_name , quiet = quiet ,
366- thread_safe = thread_safe )
367- elif ssh_key_file :
368- c = SSHConnection (ssh_server , ssh_username ,
369- pkey_file = ssh_key_file , line_rewrite = True ,
370- server_name = server_name , quiet = quiet ,
371- thread_safe = thread_safe )
372- else :
373- if not ssh_password :
374- ssh_password = getpass .getpass (prompt = "Enter password: " )
375-
376- c = SSHConnection (ssh_server , ssh_username ,
377- password = ssh_password , line_rewrite = True ,
378- server_name = server_name , quiet = quiet ,
379- thread_safe = thread_safe )
380-
381- return c
392+ return LocalConnection (
393+ ssh_server ,
394+ ssh_username ,
395+ pkey_file = ssh_key_file ,
396+ server_name = server_name ,
397+ quiet = quiet
398+ )
399+ elif allow_agent :
400+ ssh_key_file = None
401+ ssh_password = None
402+ elif ssh_key_file :
403+ ssh_password = None
404+ elif not ssh_password :
405+ ssh_password = getpass .getpass (prompt = "Enter password: " )
406+
407+ return SSHConnection (
408+ ssh_server ,
409+ ssh_username ,
410+ allow_agent = allow_agent ,
411+ pkey_file = ssh_key_file ,
412+ password = ssh_password ,
413+ line_rewrite = True ,
414+ server_name = server_name ,
415+ quiet = quiet ,
416+ thread_safe = thread_safe
417+ )
0 commit comments