Skip to content

Commit 23b7d6c

Browse files
committed
First commit - added pssh library and tests
1 parent 0356c2e commit 23b7d6c

File tree

5 files changed

+187
-0
lines changed

5 files changed

+187
-0
lines changed

.travis.yml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
language: python
2+
python:
3+
- "2.5"
4+
- "2.6"
5+
- "2.7"
6+
install:
7+
- pip install -r requirements.txt --use-mirrors
8+
script: nosetests
9+
notifications:
10+
email:
11+
on_failure: change

pssh.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
#!/usr/bin/env python
2+
3+
"""Module containing wrapper classes over paramiko.SSHClient
4+
See SSHClient and ParallelSSHClient"""
5+
6+
import paramiko
7+
import os
8+
import gevent
9+
import gevent.pool
10+
from gevent import monkey
11+
monkey.patch_all()
12+
import logging
13+
import socket
14+
15+
host_logger = logging.getLogger('host_logging')
16+
handler = logging.StreamHandler()
17+
host_log_format = logging.Formatter('%(message)s')
18+
handler.setFormatter(host_log_format)
19+
host_logger.addHandler(handler)
20+
host_logger.setLevel(logging.DEBUG)
21+
22+
logger = logging.getLogger(__name__)
23+
24+
def _setup_logger(_logger):
25+
"""Setup default logger"""
26+
handler = logging.StreamHandler()
27+
log_format = logging.Formatter('%(name)s - %(asctime)s - %(levelname)s - %(message)s')
28+
handler.setFormatter(log_format)
29+
_logger.addHandler(handler)
30+
_logger.setLevel(logging.DEBUG)
31+
32+
class UnknownHostException(Exception): pass
33+
class ConnectionErrorException(Exception): pass
34+
35+
class SSHClient(object):
36+
"""Wrapper class over paramiko.SSHClient with sane defaults"""
37+
38+
def __init__(self, host,
39+
user = None):
40+
"""Connect to host honoring any user set configuration in ~/.ssh/config or /etc/ssh/ssh_config
41+
:type: str
42+
:param host: Hostname to connect to
43+
:throws: paramiko.AuthenticationException on authentication error
44+
:throws: ssh_client.UnknownHostException on DNS resolution error
45+
:throws: ssh_client.ConnectionErrorException on error connecting"""
46+
ssh_config = paramiko.SSHConfig()
47+
_ssh_config_file = os.path.sep.join([os.path.expanduser('~'),
48+
'.ssh',
49+
'config'])
50+
# Load ~/.ssh/config if it exists to pick up username
51+
# and host address if set
52+
if os.path.isfile(_ssh_config_file):
53+
ssh_config.parse(open(_ssh_config_file))
54+
host_config = ssh_config.lookup(host)
55+
resolved_address = (host_config['hostname'] if
56+
'hostname' in host_config
57+
else host)
58+
_user = host_config['user'] if 'user' in host_config else None
59+
if user:
60+
user = user
61+
else:
62+
user = _user
63+
client = paramiko.SSHClient()
64+
client.set_missing_host_key_policy(paramiko.MissingHostKeyPolicy())
65+
self.client = client
66+
self.channel = None
67+
self.user = user
68+
self.host = resolved_address
69+
self._connect()
70+
71+
def _connect(self):
72+
"""Connect to host, throw UnknownHost exception on DNS errors"""
73+
try:
74+
self.client.connect(self.host, username = self.user)
75+
except socket.gaierror as e:
76+
logger.error("Could not resolve host %s" % (self.host,))
77+
raise UnknownHostException("%s - %s" % (str(e.strerror), self.host,))
78+
except socket.error as e:
79+
logger.error("Error connecting to host %s" % (self.host,))
80+
raise ConnectionErrorException("%s for host '%s'" % (str(e.strerror), self.host,))
81+
82+
def exec_command(self, command, sudo = False, **kwargs):
83+
"""Wrapper to paramiko.SSHClient.exec_command"""
84+
channel = self.client.get_transport().open_session()
85+
channel.get_pty()
86+
_, stdout, stderr = channel.makefile('wb'), channel.makefile('rb'), channel.makefile_stderr('rb')
87+
if sudo:
88+
command = 'sudo -S bash -c "%s"' % command.replace('"', '\\"')
89+
logger.debug("Running command %s on %s" % (command, self.host))
90+
channel.exec_command(command, **kwargs)
91+
logger.debug("Command finished executing")
92+
while not channel.recv_ready():
93+
gevent.sleep(.2)
94+
return channel, self.host, stdout, stderr
95+
96+
class ParallelSSHClient(object):
97+
"""Uses SSHClient, runs command on multiple hosts in parallel"""
98+
99+
def __init__(self, hosts, pool_size = 10,
100+
user = None):
101+
"""Connect to hosts
102+
:type: list(str)
103+
:param hosts: Hosts to connect to
104+
:type: int
105+
:param pool_size: Pool size - how many commands to run in parallel
106+
:throws: paramiko.AuthenticationException on authentication error
107+
:throws: ssh_client.UnknownHostException on DNS resolution error"""
108+
self.pool = gevent.pool.Pool(size = pool_size)
109+
self.pool_size = pool_size
110+
self.hosts = hosts
111+
self.user = user
112+
# Initialise connections to all hosts
113+
self.host_clients = dict((host, SSHClient(host, user = user)) for host in hosts)
114+
115+
def exec_command(self, *args, **kwargs):
116+
"""Run command on all hosts in parallel, honoring self.pool_size"""
117+
return [self.pool.spawn(self._exec_command, host, *args, **kwargs) for host in self.hosts]
118+
119+
def _exec_command(self, host, *args, **kwargs):
120+
"""Make SSHClient, run command on host"""
121+
return self.host_clients[host].exec_command(*args, **kwargs)
122+
123+
def get_stdout(self, greenlet):
124+
"""Print stdout from greenlet and return exit code for host"""
125+
channel, host, stdout, stderr = greenlet.get()
126+
for line in stdout:
127+
host_logger.info("[%s]\t%s" % (host, line.strip(),))
128+
for line in stderr:
129+
host_logger.info("[%s] [err] %s" % (host, line.strip(),))
130+
channel.close()
131+
return {host : {'exit_code' : channel.recv_exit_status()}}
132+
133+
def test():
134+
client = SSHClient('localhost')
135+
channel, host, stdout, stderr = client.exec_command('ls -ltrh')
136+
for line in stdout:
137+
print line.strip()
138+
139+
if __name__ == "__main__":
140+
_setup_logger(logger)
141+
test()

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
paramiko
2+
gevent

setup.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from distutils.core import setup
2+
from setuptools import find_packages
3+
4+
setup(name='parallel-ssh',
5+
version='0.1',
6+
description='Wrapper library over paramiko to allow remote execution of tasks. Supports parallel execution on multiple hosts',
7+
author='Panos Kittenis',
8+
author_email='[email protected]',
9+
packages = find_packages('.'),
10+
install_requires = open('requirements.txt').readlines(),
11+
)

tests/test_ssh_client.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#!/usr/bin/env python2.7
2+
3+
"""Unittests for verrot webapp"""
4+
5+
import unittest
6+
from pssh import SSHClient, ParallelSSHClient, UnknownHostException
7+
from paramiko import AuthenticationException
8+
9+
class SSHClientTest(unittest.TestCase):
10+
11+
def test_ssh_client(self):
12+
try:
13+
client = SSHClient('testy')
14+
except UnknownHostException as e:
15+
print e
16+
return
17+
stdin, stdout, stderr = client.exec_command('ls -ltrh')
18+
for line in stdout:
19+
print line.strip()
20+
21+
if __name__ == '__main__':
22+
unittest.main()

0 commit comments

Comments
 (0)