Skip to content

Commit 5a6e99f

Browse files
committed
Initialise connections in parallel to speed up running on many hosts
1 parent 062d7f1 commit 5a6e99f

File tree

1 file changed

+13
-9
lines changed

1 file changed

+13
-9
lines changed

pssh.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class SSHClient(object):
3636

3737
def __init__(self, host,
3838
user = None,
39-
password = None):
39+
password = None, port = None):
4040
"""Connect to host honoring any user set configuration in ~/.ssh/config or /etc/ssh/ssh_config
4141
:type: str
4242
:param host: Hostname to connect to
@@ -68,16 +68,14 @@ def __init__(self, host,
6868
self.channel = None
6969
self.user = user
7070
self.password = password
71+
self.port = port if port else 22
7172
self.host = resolved_address
7273
self._connect()
7374

7475
def _connect(self):
7576
"""Connect to host, throw UnknownHost exception on DNS errors"""
7677
try:
77-
if self.password:
78-
self.client.connect(self.host, username=self.user, password=self.password)
79-
else:
80-
self.client.connect(self.host, username=self.user)
78+
self.client.connect(self.host, username=self.user, password=self.password, port = self.port)
8179
except socket.gaierror, e:
8280
logger.error("Could not resolve host '%s'" % (self.host,))
8381
raise UnknownHostException("%s - %s" % (str(e.args[1]), self.host,))
@@ -87,6 +85,7 @@ def _connect(self):
8785

8886
def exec_command(self, command, sudo = False, **kwargs):
8987
"""Wrapper to paramiko.SSHClient.exec_command"""
88+
9089
channel = self.client.get_transport().open_session()
9190
channel.get_pty()
9291
_, stdout, stderr = channel.makefile('wb'), channel.makefile('rb'), channel.makefile_stderr('rb')
@@ -103,7 +102,7 @@ class ParallelSSHClient(object):
103102
"""Uses SSHClient, runs command on multiple hosts in parallel"""
104103

105104
def __init__(self, hosts,
106-
user = None, password = None,
105+
user = None, password = None, port = None,
107106
pool_size = 10):
108107
"""Connect to hosts
109108
:type: list(str)
@@ -120,16 +119,21 @@ def __init__(self, hosts,
120119
self.pool = gevent.pool.Pool(size = pool_size)
121120
self.pool_size = pool_size
122121
self.hosts = hosts
123-
124-
# Initialise connections to all hosts
125-
self.host_clients = dict((host, SSHClient(host, user = user, password = password)) for host in hosts)
122+
self.user = user
123+
self.password = password
124+
self.port = port
125+
# To hold host clients
126+
self.host_clients = dict((host, None) for host in hosts)
126127

127128
def exec_command(self, *args, **kwargs):
128129
"""Run command on all hosts in parallel, honoring self.pool_size"""
129130
return [self.pool.spawn(self._exec_command, host, *args, **kwargs) for host in self.hosts]
130131

131132
def _exec_command(self, host, *args, **kwargs):
132133
"""Make SSHClient, run command on host"""
134+
if not self.host_clients[host]:
135+
self.host_clients[host] = SSHClient(host, user = self.user, password = self.password,
136+
port = self.port)
133137
return self.host_clients[host].exec_command(*args, **kwargs)
134138

135139
def get_stdout(self, greenlet):

0 commit comments

Comments
 (0)