1616host_log_format = logging .Formatter ('%(message)s' )
1717handler .setFormatter (host_log_format )
1818host_logger .addHandler (handler )
19- host_logger .setLevel (logging .DEBUG )
19+ host_logger .setLevel (logging .INFO )
2020
2121logger = logging .getLogger (__name__ )
2222
@@ -29,7 +29,7 @@ def _setup_logger(_logger):
2929 _logger .addHandler (handler )
3030 _logger .setLevel (logging .DEBUG )
3131
32-
32+
3333class UnknownHostException (Exception ):
3434 """Raised when a host is unknown (dns failure)"""
3535 pass
@@ -49,8 +49,7 @@ class SSHClient(object):
4949 """Wrapper class over paramiko.SSHClient with sane defaults"""
5050
5151 def __init__ (self , host ,
52- user = None ,
53- password = None , port = None ):
52+ user = None , password = None , port = None ):
5453 """Connect to host honoring any user set configuration in ~/.ssh/config
5554 or /etc/ssh/ssh_config
5655 :type: str
@@ -121,6 +120,42 @@ def exec_command(self, command, sudo=False, **kwargs):
121120 gevent .sleep (.2 )
122121 return channel , self .host , stdout , stderr
123122
123+ def _make_sftp (self ):
124+ """Make SFTP client from open transport"""
125+ transport = self .client .get_transport ()
126+ channel = transport .open_session ()
127+ return paramiko .SFTPClient .from_transport (transport )
128+
129+ def mkdir (self , sftp , directory ):
130+ """Make directory via SFTP channel"""
131+ try :
132+ sftp .mkdir (directory )
133+ except IOError , error :
134+ logger .error ("Error occured creating directory on %s - %s" ,
135+ self .host , error )
136+
137+ def copy_file (self , local_file , remote_file ):
138+ """Copy local file to host via SFTP"""
139+ sftp = self ._make_sftp ()
140+ destination = remote_file .split (os .path .sep )
141+ filename = destination [0 ] if len (destination ) == 1 else destination [- 1 ]
142+ remote_file = os .path .sep .join (destination )
143+ destination = destination [:- 1 ]
144+ # import ipdb; ipdb.set_trace()
145+ for directory in destination :
146+ try :
147+ sftp .stat (directory )
148+ except IOError :
149+ self .mkdir (sftp , directory )
150+ try :
151+ sftp .put (local_file , remote_file )
152+ except Exception , error :
153+ logger .error ("Error occured copying file to host %s - %s" ,
154+ self .host , error )
155+ else :
156+ logger .info ("Copied local file %s to remote destination %s:%s" ,
157+ local_file , self .host , remote_file )
158+
124159class ParallelSSHClient (object ):
125160 """Uses SSHClient, runs command on multiple hosts in parallel"""
126161
@@ -153,7 +188,7 @@ def exec_command(self, *args, **kwargs):
153188 """Run command on all hosts in parallel, honoring self.pool_size"""
154189 return [self .pool .spawn (self ._exec_command , host , * args , ** kwargs )
155190 for host in self .hosts ]
156-
191+
157192 def _exec_command (self , host , * args , ** kwargs ):
158193 """Make SSHClient, run command on host"""
159194 if not self .host_clients [host ]:
@@ -172,13 +207,35 @@ def get_stdout(self, greenlet):
172207 channel .close ()
173208 return {host : {'exit_code' : channel .recv_exit_status ()}}
174209
210+ def copy_file (self , local_file , remote_file ):
211+ """Copy local file to remote file in parallel"""
212+ return [self .pool .spawn (self ._copy_file , host , local_file , remote_file )
213+ for host in self .hosts ]
175214
215+ def _copy_file (self , host , local_file , remote_file ):
216+ """Make sftp client, copy file"""
217+ if not self .host_clients [host ]:
218+ self .host_clients [host ] = SSHClient (host , user = self .user ,
219+ password = self .password ,
220+ port = self .port )
221+ return self .host_clients [host ].copy_file (local_file , remote_file )
222+
223+
176224def test ():
177225 client = SSHClient ('localhost' )
178226 channel , host , stdout , stderr = client .exec_command ('ls -ltrh' )
179227 for line in stdout :
180228 print line .strip ()
229+ client .copy_file ('../test' , 'test_dir/test' )
230+
231+ def test_parallel ():
232+ client = ParallelSSHClient (['localhost' ])
233+ cmds = client .exec_command ('ls -ltrh' )
234+ print [client .get_stdout (cmd ) for cmd in cmds ]
235+ cmds = client .copy_file ('../test' , 'test_dir/test' )
236+ client .pool .join ()
181237
182238if __name__ == "__main__" :
183239 _setup_logger (logger )
184240 test ()
241+ test_parallel ()
0 commit comments