@@ -36,7 +36,7 @@ class SSHClient(object):
3636
3737 def __init__ (self , host ,
3838 user = None ,
39- password = None ):
39+ password = None , port = None ):
4040 """Connect to host honoring any user set configuration in ~/.ssh/config or /etc/ssh/ssh_config
4141 :type: str
4242 :param host: Hostname to connect to
@@ -68,16 +68,14 @@ def __init__(self, host,
6868 self .channel = None
6969 self .user = user
7070 self .password = password
71+ self .port = port if port else 22
7172 self .host = resolved_address
7273 self ._connect ()
7374
7475 def _connect (self ):
7576 """Connect to host, throw UnknownHost exception on DNS errors"""
7677 try :
77- if self .password :
78- self .client .connect (self .host , username = self .user , password = self .password )
79- else :
80- self .client .connect (self .host , username = self .user )
78+ self .client .connect (self .host , username = self .user , password = self .password , port = self .port )
8179 except socket .gaierror , e :
8280 logger .error ("Could not resolve host '%s'" % (self .host ,))
8381 raise UnknownHostException ("%s - %s" % (str (e .args [1 ]), self .host ,))
@@ -87,6 +85,7 @@ def _connect(self):
8785
8886 def exec_command (self , command , sudo = False , ** kwargs ):
8987 """Wrapper to paramiko.SSHClient.exec_command"""
88+
9089 channel = self .client .get_transport ().open_session ()
9190 channel .get_pty ()
9291 _ , stdout , stderr = channel .makefile ('wb' ), channel .makefile ('rb' ), channel .makefile_stderr ('rb' )
@@ -103,7 +102,7 @@ class ParallelSSHClient(object):
103102 """Uses SSHClient, runs command on multiple hosts in parallel"""
104103
105104 def __init__ (self , hosts ,
106- user = None , password = None ,
105+ user = None , password = None , port = None ,
107106 pool_size = 10 ):
108107 """Connect to hosts
109108 :type: list(str)
@@ -120,16 +119,21 @@ def __init__(self, hosts,
120119 self .pool = gevent .pool .Pool (size = pool_size )
121120 self .pool_size = pool_size
122121 self .hosts = hosts
123-
124- # Initialise connections to all hosts
125- self .host_clients = dict ((host , SSHClient (host , user = user , password = password )) for host in hosts )
122+ self .user = user
123+ self .password = password
124+ self .port = port
125+ # To hold host clients
126+ self .host_clients = dict ((host , None ) for host in hosts )
126127
127128 def exec_command (self , * args , ** kwargs ):
128129 """Run command on all hosts in parallel, honoring self.pool_size"""
129130 return [self .pool .spawn (self ._exec_command , host , * args , ** kwargs ) for host in self .hosts ]
130131
131132 def _exec_command (self , host , * args , ** kwargs ):
132133 """Make SSHClient, run command on host"""
134+ if not self .host_clients [host ]:
135+ self .host_clients [host ] = SSHClient (host , user = self .user , password = self .password ,
136+ port = self .port )
133137 return self .host_clients [host ].exec_command (* args , ** kwargs )
134138
135139 def get_stdout (self , greenlet ):
0 commit comments