Skip to content

Commit ae78e65

Browse files
committed
Added SFTP support for sending files to remote hosts. Includes directory support between different OSs and creating directories if they do not exist remotely. Closes #2
1 parent 1141a1a commit ae78e65

File tree

2 files changed

+76
-5
lines changed

2 files changed

+76
-5
lines changed

pssh.py

Lines changed: 62 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
host_log_format = logging.Formatter('%(message)s')
1717
handler.setFormatter(host_log_format)
1818
host_logger.addHandler(handler)
19-
host_logger.setLevel(logging.DEBUG)
19+
host_logger.setLevel(logging.INFO)
2020

2121
logger = logging.getLogger(__name__)
2222

@@ -29,7 +29,7 @@ def _setup_logger(_logger):
2929
_logger.addHandler(handler)
3030
_logger.setLevel(logging.DEBUG)
3131

32-
32+
3333
class UnknownHostException(Exception):
3434
"""Raised when a host is unknown (dns failure)"""
3535
pass
@@ -49,8 +49,7 @@ class SSHClient(object):
4949
"""Wrapper class over paramiko.SSHClient with sane defaults"""
5050

5151
def __init__(self, host,
52-
user=None,
53-
password=None, port=None):
52+
user=None, password=None, port=None):
5453
"""Connect to host honoring any user set configuration in ~/.ssh/config
5554
or /etc/ssh/ssh_config
5655
:type: str
@@ -121,6 +120,42 @@ def exec_command(self, command, sudo=False, **kwargs):
121120
gevent.sleep(.2)
122121
return channel, self.host, stdout, stderr
123122

123+
def _make_sftp(self):
124+
"""Make SFTP client from open transport"""
125+
transport = self.client.get_transport()
126+
channel = transport.open_session()
127+
return paramiko.SFTPClient.from_transport(transport)
128+
129+
def mkdir(self, sftp, directory):
130+
"""Make directory via SFTP channel"""
131+
try:
132+
sftp.mkdir(directory)
133+
except IOError, error:
134+
logger.error("Error occured creating directory on %s - %s",
135+
self.host, error)
136+
137+
def copy_file(self, local_file, remote_file):
138+
"""Copy local file to host via SFTP"""
139+
sftp = self._make_sftp()
140+
destination = remote_file.split(os.path.sep)
141+
filename = destination[0] if len(destination) == 1 else destination[-1]
142+
remote_file = os.path.sep.join(destination)
143+
destination = destination[:-1]
144+
# import ipdb; ipdb.set_trace()
145+
for directory in destination:
146+
try:
147+
sftp.stat(directory)
148+
except IOError:
149+
self.mkdir(sftp, directory)
150+
try:
151+
sftp.put(local_file, remote_file)
152+
except Exception, error:
153+
logger.error("Error occured copying file to host %s - %s",
154+
self.host, error)
155+
else:
156+
logger.info("Copied local file %s to remote destination %s:%s",
157+
local_file, self.host, remote_file)
158+
124159
class ParallelSSHClient(object):
125160
"""Uses SSHClient, runs command on multiple hosts in parallel"""
126161

@@ -153,7 +188,7 @@ def exec_command(self, *args, **kwargs):
153188
"""Run command on all hosts in parallel, honoring self.pool_size"""
154189
return [self.pool.spawn(self._exec_command, host, *args, **kwargs)
155190
for host in self.hosts]
156-
191+
157192
def _exec_command(self, host, *args, **kwargs):
158193
"""Make SSHClient, run command on host"""
159194
if not self.host_clients[host]:
@@ -172,13 +207,35 @@ def get_stdout(self, greenlet):
172207
channel.close()
173208
return {host: {'exit_code': channel.recv_exit_status()}}
174209

210+
def copy_file(self, local_file, remote_file):
211+
"""Copy local file to remote file in parallel"""
212+
return [self.pool.spawn(self._copy_file, host, local_file, remote_file)
213+
for host in self.hosts]
175214

215+
def _copy_file(self, host, local_file, remote_file):
216+
"""Make sftp client, copy file"""
217+
if not self.host_clients[host]:
218+
self.host_clients[host] = SSHClient(host, user=self.user,
219+
password=self.password,
220+
port=self.port)
221+
return self.host_clients[host].copy_file(local_file, remote_file)
222+
223+
176224
def test():
177225
client = SSHClient('localhost')
178226
channel, host, stdout, stderr = client.exec_command('ls -ltrh')
179227
for line in stdout:
180228
print line.strip()
229+
client.copy_file('../test', 'test_dir/test')
230+
231+
def test_parallel():
232+
client = ParallelSSHClient(['localhost'])
233+
cmds = client.exec_command('ls -ltrh')
234+
print [client.get_stdout(cmd) for cmd in cmds]
235+
cmds = client.copy_file('../test', 'test_dir/test')
236+
client.pool.join()
181237

182238
if __name__ == "__main__":
183239
_setup_logger(logger)
184240
test()
241+
test_parallel()

tests/test_ssh_client.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,20 @@ def test_ssh_client(self):
1717
stdin, stdout, stderr = client.exec_command('ls -ltrh')
1818
for line in stdout:
1919
print line.strip()
20+
client.copy_file("fake file", "fake file")
21+
22+
class ParallelSSHClientTest(unittest.TestCase):
23+
24+
def test_parallel_ssh_client(self):
25+
client = ParallelSSHClient(['testy'])
26+
cmds = client.exec_command('ls -ltrh')
27+
try:
28+
print [client.get_stdout(cmd) for cmd in cmds]
29+
except UnknownHostException, e:
30+
print e
31+
return
32+
cmds = client.copy_file('fake file', 'fake file')
33+
client.pool.join()
2034

2135
if __name__ == '__main__':
2236
unittest.main()

0 commit comments

Comments
 (0)